Make proxy use roundrobin to choose replica (#17063)

Fixes: #17055

Signed-off-by: yangxuan <xuan.yang@zilliz.com>
XuanYang-cn 2022-05-17 22:35:57 +08:00 committed by GitHub
parent b37b87eb97
commit 127dd34b37
7 changed files with 141 additions and 55 deletions
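
The change in a nutshell: every time the proxy reads the cached shard leaders for a collection, it rotates each channel's leader list by one, so consecutive search/query requests are spread across replicas instead of always hitting the first leader. A minimal, self-contained sketch of that rotation, with UniqueID simplified to int64 (the real helper, updateShardsWithRoundRobin, appears in the diff below):

package main

import "fmt"

// stand-in for the proxy's new queryNode type (UniqueID simplified to int64)
type queryNode struct {
	nodeID  int64
	address string
}

// rotate every channel's leader list left by one so the next request
// starts from a different replica than the previous one
func updateShardsWithRoundRobin(shards map[string][]queryNode) map[string][]queryNode {
	for ch, leaders := range shards {
		if len(leaders) <= 1 {
			continue
		}
		shards[ch] = append(leaders[1:], leaders[0])
	}
	return shards
}

func main() {
	shards := map[string][]queryNode{
		"channel-1": {{1, "addr1"}, {2, "addr2"}, {3, "addr3"}},
	}
	for i := 0; i < 3; i++ {
		shards = updateShardsWithRoundRobin(shards)
		fmt.Println(shards["channel-1"][0]) // {2 addr2}, then {3 addr3}, then {1 addr1}
	}
}

Rotating the cached list in place makes the cache itself the round-robin cursor, so no extra per-collection counter is needed.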


@@ -712,7 +712,7 @@ func (c *ChannelManager) Release(nodeID UniqueID, channelName string) error {
 	toReleaseChannel := c.getChannelByNodeAndName(nodeID, channelName)
 	if toReleaseChannel == nil {
-		return fmt.Errorf("fail to find matching nodID: %d with channelName: %s", nodeID, channelName)
+		return fmt.Errorf("fail to find matching nodeID: %d with channelName: %s", nodeID, channelName)
 	}

 	toReleaseUpdates := getReleaseOp(nodeID, toReleaseChannel)
@@ -731,7 +731,7 @@ func (c *ChannelManager) toDelete(nodeID UniqueID, channelName string) error {
 	ch := c.getChannelByNodeAndName(nodeID, channelName)
 	if ch == nil {
-		return fmt.Errorf("fail to find matching nodID: %d with channelName: %s", nodeID, channelName)
+		return fmt.Errorf("fail to find matching nodeID: %d with channelName: %s", nodeID, channelName)
 	}

 	if !c.isMarkedDrop(channelName) {


@@ -53,7 +53,7 @@ type Cache interface {
 	GetPartitionInfo(ctx context.Context, collectionName string, partitionName string) (*partitionInfo, error)
 	// GetCollectionSchema get collection's schema.
 	GetCollectionSchema(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error)
-	GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) ([]*querypb.ShardLeadersList, error)
+	GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) (map[string][]queryNode, error)
 	ClearShards(collectionName string)
 	RemoveCollection(ctx context.Context, collectionName string)
 	RemovePartition(ctx context.Context, collectionName string, partitionName string)
@@ -70,7 +70,7 @@ type collectionInfo struct {
 	collID              typeutil.UniqueID
 	schema              *schemapb.CollectionSchema
 	partInfo            map[string]*partitionInfo
-	shardLeaders        []*querypb.ShardLeadersList
+	shardLeaders        map[string][]queryNode
 	createdTimestamp    uint64
 	createdUtcTimestamp uint64
 }
@@ -528,7 +528,7 @@ func (m *MetaCache) GetCredUsernames(ctx context.Context) ([]string, error) {
 }

 // GetShards update cache if withCache == false
-func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) ([]*querypb.ShardLeadersList, error) {
+func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) (map[string][]queryNode, error) {
 	info, err := m.GetCollectionInfo(ctx, collectionName)
 	if err != nil {
 		return nil, err
@@ -536,7 +536,12 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam
 	if withCache {
 		if len(info.shardLeaders) > 0 {
-			return info.shardLeaders, nil
+			shards := updateShardsWithRoundRobin(info.shardLeaders)
+
+			m.mu.Lock()
+			m.collInfo[collectionName].shardLeaders = shards
+			m.mu.Unlock()
+			return shards, nil
 		}
 		log.Info("no shard cache for collection, try to get shard leaders from QueryCoord",
 			zap.String("collectionName", collectionName))
@@ -557,7 +562,9 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam
 		return nil, fmt.Errorf("fail to get shard leaders from QueryCoord: %s", resp.Status.Reason)
 	}

-	shards := resp.GetShards()
+	shards := parseShardLeaderList2QueryNode(resp.GetShards())
+	shards = updateShardsWithRoundRobin(shards)

 	m.mu.Lock()
 	m.collInfo[collectionName].shardLeaders = shards
@@ -566,6 +573,22 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam
 	return shards, nil
 }

+func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) map[string][]queryNode {
+	shard2QueryNodes := make(map[string][]queryNode)
+
+	for _, leaders := range shardsLeaders {
+		qns := make([]queryNode, len(leaders.GetNodeIds()))
+		for j := range qns {
+			qns[j] = queryNode{leaders.GetNodeIds()[j], leaders.GetNodeAddrs()[j]}
+		}
+
+		shard2QueryNodes[leaders.GetChannelName()] = qns
+	}
+
+	return shard2QueryNodes
+}
+
 // ClearShards clear the shard leader cache of a collection
 func (m *MetaCache) ClearShards(collectionName string) {
 	log.Info("clearing shard cache for collection", zap.String("collectionName", collectionName))


@@ -344,8 +344,8 @@ func TestMetaCache_GetShards(t *testing.T) {
 		assert.NoError(t, err)
 		assert.NotEmpty(t, shards)
 		assert.Equal(t, 1, len(shards))
-		assert.Equal(t, 3, len(shards[0].GetNodeAddrs()))
-		assert.Equal(t, 3, len(shards[0].GetNodeIds()))
+		assert.Equal(t, 3, len(shards["channel-1"]))

 		// get from cache
 		qc.validShardLeaders = false
@@ -353,8 +353,7 @@ func TestMetaCache_GetShards(t *testing.T) {
 		assert.NoError(t, err)
 		assert.NotEmpty(t, shards)
 		assert.Equal(t, 1, len(shards))
-		assert.Equal(t, 3, len(shards[0].GetNodeAddrs()))
-		assert.Equal(t, 3, len(shards[0].GetNodeIds()))
+		assert.Equal(t, 3, len(shards["channel-1"]))
 	})
 }
@@ -387,8 +386,7 @@ func TestMetaCache_ClearShards(t *testing.T) {
 		require.NoError(t, err)
 		require.NotEmpty(t, shards)
 		require.Equal(t, 1, len(shards))
-		require.Equal(t, 3, len(shards[0].GetNodeAddrs()))
-		require.Equal(t, 3, len(shards[0].GetNodeIds()))
+		require.Equal(t, 3, len(shards["channel-1"]))

 		globalMetaCache.ClearShards(collectionName)


@@ -3,11 +3,11 @@ package proxy

 import (
 	"context"
 	"errors"
+	"fmt"

 	qnClient "github.com/milvus-io/milvus/internal/distributed/querynode/client"
 	"github.com/milvus-io/milvus/internal/log"
-	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/types"
 	"go.uber.org/zap"
@@ -15,7 +15,7 @@ import (
 )

 type getQueryNodePolicy func(context.Context, string) (types.QueryNode, error)

-type pickShardPolicy func(ctx context.Context, policy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders *querypb.ShardLeadersList) error
+type pickShardPolicy func(ctx context.Context, policy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders []queryNode) error

 // TODO add another policy to enable the use of cache
 // defaultGetQueryNodePolicy creates a QueryNode client for every address every time
@@ -40,23 +40,45 @@ var (
 	errInvalidShardLeaders = errors.New("Invalid shard leader")
 )

-func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders *querypb.ShardLeadersList) error {
+type queryNode struct {
+	nodeID  UniqueID
+	address string
+}
+
+func (q queryNode) String() string {
+	return fmt.Sprintf("<NodeID: %d>", q.nodeID)
+}
+
+func updateShardsWithRoundRobin(shardsLeaders map[string][]queryNode) map[string][]queryNode {
+	for channelID, leaders := range shardsLeaders {
+		if len(leaders) <= 1 {
+			continue
+		}
+
+		shardsLeaders[channelID] = append(leaders[1:], leaders[0])
+	}
+
+	return shardsLeaders
+}
+
+func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders []queryNode) error {
 	var (
 		err     = errBegin
 		current = 0
 		qn      types.QueryNode
 	)
-	replicaNum := len(leaders.GetNodeIds())
+	replicaNum := len(leaders)

 	for err != nil && current < replicaNum {
-		currentID := leaders.GetNodeIds()[current]
+		currentID := leaders[current].nodeID
 		if err != errBegin {
 			log.Warn("retry with another QueryNode",
 				zap.Int("retries numbers", current),
-				zap.String("leader", leaders.GetChannelName()), zap.Int64("nodeID", currentID))
+				zap.Int64("nodeID", currentID))
 		}

-		qn, err = getQueryNodePolicy(ctx, leaders.GetNodeAddrs()[current])
+		qn, err = getQueryNodePolicy(ctx, leaders[current].address)
 		if err != nil {
 			log.Warn("fail to get valid QueryNode", zap.Int64("nodeID", currentID),
 				zap.Error(err))
@@ -68,7 +90,6 @@ func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy
 		err = query(currentID, qn)
 		if err != nil {
 			log.Warn("fail to Query with shard leader",
-				zap.String("leader", leaders.GetChannelName()),
 				zap.Int64("nodeID", currentID),
 				zap.Error(err))
 		}
@@ -76,9 +97,8 @@ func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy
 	}

 	if current == replicaNum && err != nil {
-		log.Warn("no shard leaders available for channel",
-			zap.String("channel name", leaders.GetChannelName()),
-			zap.Int64s("leaders", leaders.GetNodeIds()), zap.Error(err))
+		log.Warn("no shard leaders available",
+			zap.String("leaders", fmt.Sprintf("%v", leaders)), zap.Error(err))
 		// needs to return the error from query
 		return err
 	}
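
Note that roundRobinPolicy itself does not rotate anything: it walks whatever leader order it is handed, dialing each node and retrying the query closure until one succeeds, so the rotation done in the cache is what spreads the load. A sketch of the calling pattern, assuming proxy-package scope; queryOneShard is a hypothetical wrapper and the RPC body is elided:

// hypothetical wrapper inside the proxy package, for illustration only
func queryOneShard(ctx context.Context, leaders []queryNode) error {
	query := func(nodeID UniqueID, qn types.QueryNode) error {
		// issue the real Query/Search RPC here; a non-nil return makes
		// roundRobinPolicy retry with the next leader in the slice
		return nil
	}
	// defaultGetQueryNodePolicy (referenced in the diff above) dials a
	// fresh QueryNode client for each address it receives
	return roundRobinPolicy(ctx, defaultGetQueryNodePolicy, query, leaders)
}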


@@ -5,11 +5,53 @@ import (
 	"fmt"
 	"testing"

-	"github.com/milvus-io/milvus/internal/proto/querypb"
+	"github.com/milvus-io/milvus/internal/log"
 	"github.com/milvus-io/milvus/internal/types"
+	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
+	"go.uber.org/zap"
 )

+func TestUpdateShardsWithRoundRobin(t *testing.T) {
+	in := map[string][]queryNode{
+		"channel-1": {
+			{1, "addr1"},
+			{2, "addr2"},
+		},
+		"channel-2": {
+			{20, "addr20"},
+			{21, "addr21"},
+		},
+	}
+
+	out := updateShardsWithRoundRobin(in)
+
+	assert.Equal(t, int64(2), out["channel-1"][0].nodeID)
+	assert.Equal(t, "addr2", out["channel-1"][0].address)
+	assert.Equal(t, int64(21), out["channel-2"][0].nodeID)
+	assert.Equal(t, "addr21", out["channel-2"][0].address)
+
+	t.Run("check print", func(t *testing.T) {
+		qns := []queryNode{
+			{1, "addr1"},
+			{2, "addr2"},
+			{20, "addr20"},
+			{21, "addr21"},
+		}
+		res := fmt.Sprintf("list: %v", qns)
+		log.Debug("Check String func",
+			zap.Any("Any", qns),
+			zap.Any("ok", qns[0]),
+			zap.String("ok2", res),
+		)
+	})
+}
+
 func TestRoundRobinPolicy(t *testing.T) {
 	var (
 		getQueryNodePolicy = mockGetQueryNodePolicy
@@ -31,11 +31,12 @@ func TestRoundRobinPolicy(t *testing.T) {
 		t.Run(test.description, func(t *testing.T) {
 			query := (&mockQuery{isvalid: false}).query
-			leaders := &querypb.ShardLeadersList{
-				ChannelName: t.Name(),
-				NodeIds:     test.leaderIDs,
-				NodeAddrs:   make([]string, len(test.leaderIDs)),
+			leaders := make([]queryNode, 0, len(test.leaderIDs))
+			for _, ID := range test.leaderIDs {
+				leaders = append(leaders, queryNode{ID, "random-addr"})
 			}
+
 			err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders)
 			require.Error(t, err)
 		})
@@ -55,10 +55,10 @@ func TestRoundRobinPolicy(t *testing.T) {
 	for _, test := range allPassTests {
 		query := (&mockQuery{isvalid: true}).query
-		leaders := &querypb.ShardLeadersList{
-			ChannelName: t.Name(),
-			NodeIds:     test.leaderIDs,
-			NodeAddrs:   make([]string, len(test.leaderIDs)),
+		leaders := make([]queryNode, 0, len(test.leaderIDs))
+		for _, ID := range test.leaderIDs {
+			leaders = append(leaders, queryNode{ID, "random-addr"})
 		}

 		err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders)
 		require.NoError(t, err)
@@ -77,10 +77,10 @@ func TestRoundRobinPolicy(t *testing.T) {
 	for _, test := range passAtLast {
 		query := (&mockQuery{isvalid: true}).query
-		leaders := &querypb.ShardLeadersList{
-			ChannelName: t.Name(),
-			NodeIds:     test.leaderIDs,
-			NodeAddrs:   make([]string, len(test.leaderIDs)),
+		leaders := make([]queryNode, 0, len(test.leaderIDs))
+		for _, ID := range test.leaderIDs {
+			leaders = append(leaders, queryNode{ID, "random-addr"})
 		}

 		err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders)
 		require.NoError(t, err)


@@ -244,16 +244,17 @@ func (t *queryTask) Execute(ctx context.Context) error {
 	t.resultBuf = make(chan *internalpb.RetrieveResults, len(shards))
 	t.toReduceResults = make([]*internalpb.RetrieveResults, 0, len(shards))
 	t.runningGroup, t.runningGroupCtx = errgroup.WithContext(ctx)
-	for _, shard := range shards {
-		s := shard
+	for channelID, leaders := range shards {
+		channelID := channelID
+		leaders := leaders
 		t.runningGroup.Go(func() error {
 			log.Debug("proxy starting to query one shard",
 				zap.Int64("collectionID", t.CollectionID),
 				zap.String("collection name", t.collectionName),
-				zap.String("shard channel", s.GetChannelName()),
+				zap.String("shard channel", channelID),
 				zap.Uint64("timeoutTs", t.TimeoutTimestamp))

-			err := t.queryShard(t.runningGroupCtx, s)
+			err := t.queryShard(t.runningGroupCtx, leaders, channelID)
 			if err != nil {
 				return err
 			}
@@ -344,12 +345,12 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
 	return nil
 }

-func (t *queryTask) queryShard(ctx context.Context, leaders *querypb.ShardLeadersList) error {
+func (t *queryTask) queryShard(ctx context.Context, leaders []queryNode, channelID string) error {
 	query := func(nodeID UniqueID, qn types.QueryNode) error {
 		req := &querypb.QueryRequest{
 			Req:           t.RetrieveRequest,
 			IsShardLeader: true,
-			DmlChannel:    leaders.GetChannelName(),
+			DmlChannel:    channelID,
 		}

 		result, err := qn.Query(ctx, req)
@@ -364,14 +365,14 @@ func (t *queryTask) queryShard(ctx context.Context, leaders *querypb.ShardLeader
 			return fmt.Errorf("fail to Query, QueryNode ID = %d, reason=%s", nodeID, result.GetStatus().GetReason())
 		}

-		log.Debug("get query result", zap.Int64("nodeID", nodeID), zap.String("channelID", leaders.GetChannelName()))
+		log.Debug("get query result", zap.Int64("nodeID", nodeID), zap.String("channelID", channelID))
 		t.resultBuf <- result
 		return nil
 	}

 	err := t.queryShardPolicy(t.TraceCtx(), t.getQueryNodePolicy, query, leaders)
 	if err != nil {
-		log.Warn("fail to Query to all shard leaders", zap.Int64("taskID", t.ID()), zap.Any("shard leaders", leaders.GetNodeIds()))
+		log.Warn("fail to Query to all shard leaders", zap.Int64("taskID", t.ID()), zap.Any("shard leaders", leaders))
 		return err
 	}
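
Execute fans the per-channel work out through errgroup, and the channelID := channelID / leaders := leaders re-declarations give each goroutine its own copy of the loop variables, which is required for correctness in the Go versions this code targets (before Go 1.22 changed loop-variable scoping). A standalone sketch of the pattern with made-up channel data:

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	shards := map[string][]string{
		"channel-1": {"addr1", "addr2"},
		"channel-2": {"addr20", "addr21"},
	}

	g, ctx := errgroup.WithContext(context.Background())
	for channelID, leaders := range shards {
		channelID := channelID // per-iteration copies: without these,
		leaders := leaders     // every goroutine would see the last pair
		g.Go(func() error {
			select {
			case <-ctx.Done(): // a sibling goroutine already failed
				return ctx.Err()
			default:
			}
			fmt.Println(channelID, leaders)
			return nil
		})
	}
	// Wait returns the first non-nil error from any shard goroutine
	if err := g.Wait(); err != nil {
		fmt.Println("one shard failed:", err)
	}
}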


@@ -258,27 +258,28 @@ func (t *searchTask) Execute(ctx context.Context) error {
 	defer tr.Elapse("done")

 	executeSearch := func(withCache bool) error {
-		shards, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName, t.qc)
+		shard2Leaders, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName, t.qc)
 		if err != nil {
 			return err
 		}

-		t.resultBuf = make(chan *internalpb.SearchResults, len(shards))
-		t.toReduceResults = make([]*internalpb.SearchResults, 0, len(shards))
+		t.resultBuf = make(chan *internalpb.SearchResults, len(shard2Leaders))
+		t.toReduceResults = make([]*internalpb.SearchResults, 0, len(shard2Leaders))
 		t.runningGroup, t.runningGroupCtx = errgroup.WithContext(ctx)

 		// TODO: try to merge rpc send to different shard leaders.
 		// If two shard leaders are on the same querynode, we could merge requests to save an rpc
-		for _, shard := range shards {
-			s := shard
+		for channelID, leaders := range shard2Leaders {
+			channelID := channelID
+			leaders := leaders
 			t.runningGroup.Go(func() error {
 				log.Debug("proxy starting to query one shard",
 					zap.Int64("collectionID", t.CollectionID),
 					zap.String("collection name", t.collectionName),
-					zap.String("shard channel", s.GetChannelName()),
+					zap.String("shard channel", channelID),
 					zap.Uint64("timeoutTs", t.TimeoutTimestamp))

-				err := t.searchShard(t.runningGroupCtx, s)
+				err := t.searchShard(t.runningGroupCtx, leaders, channelID)
 				if err != nil {
 					return err
 				}
@@ -393,13 +394,13 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
 	return nil
 }

-func (t *searchTask) searchShard(ctx context.Context, leaders *querypb.ShardLeadersList) error {
+func (t *searchTask) searchShard(ctx context.Context, leaders []queryNode, channelID string) error {
 	search := func(nodeID UniqueID, qn types.QueryNode) error {
 		req := &querypb.SearchRequest{
 			Req:           t.SearchRequest,
 			IsShardLeader: true,
-			DmlChannel:    leaders.GetChannelName(),
+			DmlChannel:    channelID,
 		}

 		result, err := qn.Search(ctx, req)
@@ -420,7 +421,7 @@ func (t *searchTask) searchShard(ctx context.Context, leaders *querypb.ShardLead
 	err := t.searchShardPolicy(t.TraceCtx(), t.getQueryNodePolicy, search, leaders)
 	if err != nil {
-		log.Warn("fail to search to all shard leaders", zap.Any("shard leaders", leaders.GetNodeIds()))
+		log.Warn("fail to search to all shard leaders", zap.Any("shard leaders", leaders))
 		return err
 	}