enhance: Enable dynamic update of the replica selection policy (#35860)
issue: #35859

Signed-off-by: Wei Liu <wei.liu@zilliz.com>

parent c03eb6f664
commit bd658a6510
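The change below removes the proxy's startup-time choice of a single replica-selection balancer. NewLBPolicyImpl now builds one balancer per supported policy (look_aside and round_robin) and exposes a getBalancer closure that re-reads the ReplicaSelectionPolicy config item (via paramtable) on every call, falling back to look_aside for unknown values, so the policy can be switched at runtime without restarting the proxy. A minimal sketch of that pattern, using stand-in types and an atomic value instead of the Milvus config and balancer types:

package main

import (
	"fmt"
	"sync/atomic"
)

// Balancer, lookAside, and roundRobin are simplified stand-ins, not the Milvus types.
type Balancer interface {
	Name() string
}

type lookAside struct{}

func (lookAside) Name() string { return "look_aside" }

type roundRobin struct{}

func (roundRobin) Name() string { return "round_robin" }

func main() {
	// policy holds the currently configured policy name; in Milvus this comes
	// from paramtable, here it is just an atomic value.
	var policy atomic.Value
	policy.Store("look_aside")

	// One balancer instance per supported policy, created once.
	balancerMap := map[string]Balancer{
		"look_aside":  lookAside{},
		"round_robin": roundRobin{},
	}

	// getBalancer re-reads the configured name on every call and falls back to
	// look_aside for unknown values, mirroring the closure in the PR.
	getBalancer := func() Balancer {
		name, _ := policy.Load().(string)
		if b, ok := balancerMap[name]; ok {
			return b
		}
		return balancerMap["look_aside"]
	}

	fmt.Println(getBalancer().Name()) // look_aside

	// Updating the config is picked up by the next request, no restart needed.
	policy.Store("round_robin")
	fmt.Println(getBalancer().Name()) // round_robin
}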
@@ -61,77 +61,75 @@ type LBPolicy interface {
 	Close()
 }

+const (
+	RoundRobin = "round_robin"
+	LookAside  = "look_aside"
+)
+
 type LBPolicyImpl struct {
-	balancer  LBBalancer
-	clientMgr shardClientMgr
+	getBalancer func() LBBalancer
+	clientMgr   shardClientMgr
+	balancerMap map[string]LBBalancer
 }

 func NewLBPolicyImpl(clientMgr shardClientMgr) *LBPolicyImpl {
-	balancePolicy := params.Params.ProxyCfg.ReplicaSelectionPolicy.GetValue()
+	balancerMap := make(map[string]LBBalancer)
+	balancerMap[LookAside] = NewLookAsideBalancer(clientMgr)
+	balancerMap[RoundRobin] = NewRoundRobinBalancer()

-	var balancer LBBalancer
-	switch balancePolicy {
-	case "round_robin":
-		log.Info("use round_robin policy on replica selection")
-		balancer = NewRoundRobinBalancer()
-	default:
-		log.Info("use look_aside policy on replica selection")
-		balancer = NewLookAsideBalancer(clientMgr)
+	getBalancer := func() LBBalancer {
+		balancePolicy := params.Params.ProxyCfg.ReplicaSelectionPolicy.GetValue()
+		if _, ok := balancerMap[balancePolicy]; !ok {
+			return balancerMap[LookAside]
+		}
+		return balancerMap[balancePolicy]
 	}

 	return &LBPolicyImpl{
-		balancer:  balancer,
-		clientMgr: clientMgr,
+		getBalancer: getBalancer,
+		clientMgr:   clientMgr,
+		balancerMap: balancerMap,
 	}
 }

 func (lb *LBPolicyImpl) Start(ctx context.Context) {
-	lb.balancer.Start(ctx)
+	for _, lb := range lb.balancerMap {
+		lb.Start(ctx)
+	}
 }

 // try to select the best node from the available nodes
-func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (int64, error) {
-	log := log.Ctx(ctx).With(
-		zap.Int64("collectionID", workload.collectionID),
-		zap.String("collectionName", workload.collectionName),
-		zap.String("channelName", workload.channel),
-	)
-
-	filterAvailableNodes := func(node int64, _ int) bool {
-		return !excludeNodes.Contain(node)
-	}
-
-	getShardLeaders := func() ([]int64, error) {
-		shardLeaders, err := globalMetaCache.GetShards(ctx, false, workload.db, workload.collectionName, workload.collectionID)
-		if err != nil {
-			return nil, err
-		}
-
-		return lo.Map(shardLeaders[workload.channel], func(node nodeInfo, _ int) int64 { return node.nodeID }), nil
-	}
-
-	availableNodes := lo.Filter(workload.shardLeaders, filterAvailableNodes)
-	targetNode, err := lb.balancer.SelectNode(ctx, availableNodes, workload.nq)
+func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (int64, error) {
+	availableNodes := lo.FilterMap(workload.shardLeaders, func(node int64, _ int) (int64, bool) { return node, !excludeNodes.Contain(node) })
+	targetNode, err := balancer.SelectNode(ctx, availableNodes, workload.nq)
 	if err != nil {
+		log := log.Ctx(ctx)
 		globalMetaCache.DeprecateShardCache(workload.db, workload.collectionName)
-		nodes, err := getShardLeaders()
-		if err != nil || len(nodes) == 0 {
+		shardLeaders, err := globalMetaCache.GetShards(ctx, false, workload.db, workload.collectionName, workload.collectionID)
+		if err != nil {
 			log.Warn("failed to get shard delegator",
+				zap.Int64("collectionID", workload.collectionID),
+				zap.String("channelName", workload.channel),
 				zap.Error(err))
 			return -1, err
 		}

-		availableNodes := lo.Filter(nodes, filterAvailableNodes)
+		availableNodes := lo.FilterMap(shardLeaders[workload.channel], func(node nodeInfo, _ int) (int64, bool) { return node.nodeID, !excludeNodes.Contain(node.nodeID) })
 		if len(availableNodes) == 0 {
+			nodes := lo.Map(shardLeaders[workload.channel], func(node nodeInfo, _ int) int64 { return node.nodeID })
 			log.Warn("no available shard delegator found",
+				zap.Int64("collectionID", workload.collectionID),
+				zap.String("channelName", workload.channel),
 				zap.Int64s("nodes", nodes),
 				zap.Int64s("excluded", excludeNodes.Collect()))
 			return -1, merr.WrapErrChannelNotAvailable("no available shard delegator found")
 		}

-		targetNode, err = lb.balancer.SelectNode(ctx, availableNodes, workload.nq)
+		targetNode, err = balancer.SelectNode(ctx, availableNodes, workload.nq)
 		if err != nil {
 			log.Warn("failed to select shard",
+				zap.Int64("collectionID", workload.collectionID),
+				zap.String("channelName", workload.channel),
 				zap.Int64s("availableNodes", availableNodes),
 				zap.Error(err))
 			return -1, err
@@ -144,17 +142,15 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload
 // ExecuteWithRetry will choose a qn to execute the workload, and retry if failed, until reach the max retryTimes.
 func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWorkload) error {
 	excludeNodes := typeutil.NewUniqueSet()
-	log := log.Ctx(ctx).With(
-		zap.Int64("collectionID", workload.collectionID),
-		zap.String("collectionName", workload.collectionName),
-		zap.String("channelName", workload.channel),
-	)

 	var lastErr error
 	err := retry.Do(ctx, func() error {
-		targetNode, err := lb.selectNode(ctx, workload, excludeNodes)
+		balancer := lb.getBalancer()
+		targetNode, err := lb.selectNode(ctx, balancer, workload, excludeNodes)
 		if err != nil {
 			log.Warn("failed to select node for shard",
+				zap.Int64("collectionID", workload.collectionID),
+				zap.String("channelName", workload.channel),
 				zap.Int64("nodeID", targetNode),
 				zap.Error(err),
 			)
@@ -163,16 +159,18 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
 			}
 			return err
 		}
+		// cancel work load which assign to the target node
+		defer balancer.CancelWorkload(targetNode, workload.nq)

 		client, err := lb.clientMgr.GetClient(ctx, targetNode)
 		if err != nil {
 			log.Warn("search/query channel failed, node not available",
+				zap.Int64("collectionID", workload.collectionID),
+				zap.String("channelName", workload.channel),
 				zap.Int64("nodeID", targetNode),
 				zap.Error(err))
 			excludeNodes.Insert(targetNode)

-			// cancel work load which assign to the target node
-			lb.balancer.CancelWorkload(targetNode, workload.nq)
 			lastErr = errors.Wrapf(err, "failed to get delegator %d for channel %s", targetNode, workload.channel)
 			return lastErr
 		}
@@ -180,16 +178,15 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
 		err = workload.exec(ctx, targetNode, client, workload.channel)
 		if err != nil {
 			log.Warn("search/query channel failed",
+				zap.Int64("collectionID", workload.collectionID),
+				zap.String("channelName", workload.channel),
 				zap.Int64("nodeID", targetNode),
 				zap.Error(err))
 			excludeNodes.Insert(targetNode)
-			lb.balancer.CancelWorkload(targetNode, workload.nq)
-
 			lastErr = errors.Wrapf(err, "failed to search/query delegator %d for channel %s", targetNode, workload.channel)
 			return lastErr
 		}

-		lb.balancer.CancelWorkload(targetNode, workload.nq)
 		return nil
 	}, retry.Attempts(workload.retryTimes))

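The ExecuteWithRetry hunks above keep the existing retry shape: resolve the active balancer once per attempt, select a delegator, run the workload, and on failure add the node to the exclude set before retrying, while the new defer releases the reserved workload on every exit path. A simplified, self-contained sketch of that retry-with-exclusion loop (the function names, node list, and error handling here are placeholders, not the Milvus proxy API):

package main

import (
	"errors"
	"fmt"
)

// executeWithRetry mirrors the shape of ExecuteWithRetry above: pick a node,
// run the workload, and on failure exclude that node and retry, up to retryTimes.
func executeWithRetry(retryTimes int,
	selectNode func(exclude map[int64]bool) (int64, error),
	exec func(node int64) error,
) error {
	exclude := make(map[int64]bool)
	var lastErr error
	for i := 0; i < retryTimes; i++ {
		node, err := selectNode(exclude)
		if err != nil {
			// No selectable node left; surface the last execution error if any.
			if lastErr != nil {
				return lastErr
			}
			return err
		}
		if err := exec(node); err != nil {
			// Remember the failure, exclude the node, and try the next one.
			lastErr = err
			exclude[node] = true
			continue
		}
		return nil
	}
	return lastErr
}

func main() {
	nodes := []int64{1, 2, 3}
	selectNode := func(exclude map[int64]bool) (int64, error) {
		for _, n := range nodes {
			if !exclude[n] {
				return n, nil
			}
		}
		return -1, errors.New("no available node")
	}
	// Node 1 fails, node 2 succeeds.
	exec := func(node int64) error {
		if node == 1 {
			return errors.New("delegator not available")
		}
		fmt.Println("executed on node", node)
		return nil
	}
	// Prints "executed on node 2" followed by "<nil>".
	fmt.Println(executeWithRetry(3, selectNode, exec))
}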
@@ -232,9 +229,11 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad
 }

 func (lb *LBPolicyImpl) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {
-	lb.balancer.UpdateCostMetrics(node, cost)
+	lb.getBalancer().UpdateCostMetrics(node, cost)
 }

 func (lb *LBPolicyImpl) Close() {
-	lb.balancer.Close()
+	for _, lb := range lb.balancerMap {
+		lb.Close()
+	}
 }
@@ -101,7 +101,9 @@ func (s *LBPolicySuite) SetupTest() {
 	s.lbBalancer.EXPECT().Start(context.Background()).Maybe()
 	s.lbPolicy = NewLBPolicyImpl(s.mgr)
 	s.lbPolicy.Start(context.Background())
-	s.lbPolicy.balancer = s.lbBalancer
+	s.lbPolicy.getBalancer = func() LBBalancer {
+		return s.lbBalancer
+	}

 	err := InitMetaCache(context.Background(), s.rc, s.qc, s.mgr)
 	s.NoError(err)
@@ -163,7 +165,7 @@ func (s *LBPolicySuite) loadCollection() {
 func (s *LBPolicySuite) TestSelectNode() {
 	ctx := context.Background()
 	s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(5, nil)
-	targetNode, err := s.lbPolicy.selectNode(ctx, ChannelWorkload{
+	targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
 		db:             dbName,
 		collectionName: s.collectionName,
 		collectionID:   s.collectionID,
@@ -178,7 +180,7 @@ func (s *LBPolicySuite) TestSelectNode() {
 	s.lbBalancer.ExpectedCalls = nil
 	s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, errors.New("fake err")).Times(1)
 	s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(3, nil)
-	targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
+	targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
 		db:             dbName,
 		collectionName: s.collectionName,
 		collectionID:   s.collectionID,
@@ -192,7 +194,7 @@ func (s *LBPolicySuite) TestSelectNode() {
 	// test select node always fails, expected failure
 	s.lbBalancer.ExpectedCalls = nil
 	s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
-	targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
+	targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
 		db:             dbName,
 		collectionName: s.collectionName,
 		collectionID:   s.collectionID,
@@ -206,7 +208,7 @@ func (s *LBPolicySuite) TestSelectNode() {
 	// test all nodes has been excluded, expected failure
 	s.lbBalancer.ExpectedCalls = nil
 	s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
-	targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
+	targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
 		db:             dbName,
 		collectionName: s.collectionName,
 		collectionID:   s.collectionID,
@@ -222,7 +224,7 @@ func (s *LBPolicySuite) TestSelectNode() {
 	s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
 	s.qc.ExpectedCalls = nil
 	s.qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(nil, merr.ErrServiceUnavailable)
-	targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
+	targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
 		db:             dbName,
 		collectionName: s.collectionName,
 		collectionID:   s.collectionID,
@@ -419,17 +421,17 @@ func (s *LBPolicySuite) TestUpdateCostMetrics() {

 func (s *LBPolicySuite) TestNewLBPolicy() {
 	policy := NewLBPolicyImpl(s.mgr)
-	s.Equal(reflect.TypeOf(policy.balancer).String(), "*proxy.LookAsideBalancer")
+	s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*proxy.LookAsideBalancer")
 	policy.Close()

 	Params.Save(Params.ProxyCfg.ReplicaSelectionPolicy.Key, "round_robin")
 	policy = NewLBPolicyImpl(s.mgr)
-	s.Equal(reflect.TypeOf(policy.balancer).String(), "*proxy.RoundRobinBalancer")
+	s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*proxy.RoundRobinBalancer")
 	policy.Close()

 	Params.Save(Params.ProxyCfg.ReplicaSelectionPolicy.Key, "look_aside")
 	policy = NewLBPolicyImpl(s.mgr)
-	s.Equal(reflect.TypeOf(policy.balancer).String(), "*proxy.LookAsideBalancer")
+	s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*proxy.LookAsideBalancer")
 	policy.Close()
 }

@@ -952,11 +952,6 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
 		zap.String("collectionName", collectionName),
 		zap.Int64("collectionID", collectionID))

-	info, err := m.getFullCollectionInfo(ctx, database, collectionName, collectionID)
-	if err != nil {
-		return nil, err
-	}
-
 	cacheShardLeaders, ok := m.getCollectionShardLeader(database, collectionName)
 	if withCache {
 		if ok {
@@ -968,6 +963,12 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
 		metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
 		log.Info("no shard cache for collection, try to get shard leaders from QueryCoord")
 	}

+	info, err := m.getFullCollectionInfo(ctx, database, collectionName, collectionID)
+	if err != nil {
+		return nil, err
+	}
+
 	req := &querypb.GetShardLeadersRequest{
 		Base: commonpbutil.NewMsgBase(
 			commonpbutil.WithMsgType(commonpb.MsgType_GetShardLeaders),
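The MetaCache.GetShards hunks above move the getFullCollectionInfo call below the shard-leader cache check, so a cache hit no longer pays for the collection-info lookup; the lookup only happens on the miss path, right before asking QueryCoord for shard leaders. A small sketch of that cache-first ordering, with stand-in types for the cache and the fetch step (not the MetaCache API):

package main

import "fmt"

// shardCache is a stand-in for the proxy's shard-leader cache.
type shardCache map[string][]int64

// getShards consults the local cache first and only runs the expensive fetch
// (collection info + shard leaders) on a miss or when the cache is bypassed.
func getShards(cache shardCache, collection string, withCache bool,
	fetch func(collection string) ([]int64, error),
) ([]int64, error) {
	if withCache {
		if leaders, ok := cache[collection]; ok {
			// Cache hit: no metadata round trip at all.
			return leaders, nil
		}
	}
	leaders, err := fetch(collection)
	if err != nil {
		return nil, err
	}
	cache[collection] = leaders
	return leaders, nil
}

func main() {
	cache := shardCache{}
	fetch := func(collection string) ([]int64, error) {
		fmt.Println("fetching shard leaders for", collection)
		return []int64{1, 2}, nil
	}
	fmt.Println(getShards(cache, "demo", true, fetch)) // fetches
	fmt.Println(getShards(cache, "demo", true, fetch)) // served from cache
}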
@@ -22,18 +22,14 @@ import (

 	"github.com/milvus-io/milvus/internal/proto/internalpb"
 	"github.com/milvus-io/milvus/pkg/util/merr"
-	"github.com/milvus-io/milvus/pkg/util/typeutil"
 )

 type RoundRobinBalancer struct {
-	// request num send to each node
-	nodeWorkload *typeutil.ConcurrentMap[int64, *atomic.Int64]
+	idx atomic.Int64
 }

 func NewRoundRobinBalancer() *RoundRobinBalancer {
-	return &RoundRobinBalancer{
-		nodeWorkload: typeutil.NewConcurrentMap[int64, *atomic.Int64](),
-	}
+	return &RoundRobinBalancer{}
 }

 func (b *RoundRobinBalancer) SelectNode(ctx context.Context, availableNodes []int64, cost int64) (int64, error) {
@@ -41,32 +37,11 @@ func (b *RoundRobinBalancer) SelectNode(ctx context.Context, availableNodes []in
 		return -1, merr.ErrNodeNotAvailable
 	}

-	targetNode := int64(-1)
-	var targetNodeWorkload *atomic.Int64
-	for _, node := range availableNodes {
-		workload, ok := b.nodeWorkload.Get(node)
-
-		if !ok {
-			workload = atomic.NewInt64(0)
-			b.nodeWorkload.Insert(node, workload)
-		}
-
-		if targetNodeWorkload == nil || workload.Load() < targetNodeWorkload.Load() {
-			targetNode = node
-			targetNodeWorkload = workload
-		}
-	}
-
-	targetNodeWorkload.Add(cost)
-	return targetNode, nil
+	idx := b.idx.Inc()
+	return availableNodes[int(idx)%len(availableNodes)], nil
 }

 func (b *RoundRobinBalancer) CancelWorkload(node int64, nq int64) {
-	load, ok := b.nodeWorkload.Get(node)
-
-	if ok {
-		load.Sub(nq)
-	}
 }

 func (b *RoundRobinBalancer) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {}
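With the per-node workload map gone, RoundRobinBalancer reduces to a single atomic counter: each SelectNode call increments it and indexes into the available nodes, and CancelWorkload becomes a no-op. A stand-alone sketch of that selection follows; the hunk's b.idx.Inc() implies the go.uber.org/atomic package, while the standard library's atomic.Int64 (Add) is used here to keep the example dependency-free:

package main

import (
	"fmt"
	"sync/atomic"
)

// roundRobin is a simplified stand-in for the new RoundRobinBalancer.
type roundRobin struct {
	idx atomic.Int64
}

// SelectNode advances the counter and picks the node at counter mod len(nodes).
func (b *roundRobin) SelectNode(nodes []int64) (int64, error) {
	if len(nodes) == 0 {
		return -1, fmt.Errorf("no node available")
	}
	idx := b.idx.Add(1)
	return nodes[int(idx)%len(nodes)], nil
}

func main() {
	b := &roundRobin{}
	nodes := []int64{1, 2, 3}
	for i := 0; i < 6; i++ {
		n, _ := b.SelectNode(nodes)
		fmt.Print(n, " ") // cycles 2 3 1 2 3 1 (the first Add returns 1)
	}
	fmt.Println()
}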
@@ -20,6 +20,8 @@ import (
 	"testing"

 	"github.com/stretchr/testify/suite"
+
+	"github.com/milvus-io/milvus/pkg/util/merr"
 )

 type RoundRobinBalancerSuite struct {
@@ -33,48 +35,34 @@ func (s *RoundRobinBalancerSuite) SetupTest() {
 	s.balancer.Start(context.Background())
 }

-func (s *RoundRobinBalancerSuite) TestRoundRobin() {
-	availableNodes := []int64{1, 2}
-	s.balancer.SelectNode(context.TODO(), availableNodes, 1)
-	s.balancer.SelectNode(context.TODO(), availableNodes, 1)
-	s.balancer.SelectNode(context.TODO(), availableNodes, 1)
-	s.balancer.SelectNode(context.TODO(), availableNodes, 1)
+func TestSelectNode(t *testing.T) {
+	balancer := NewRoundRobinBalancer()

-	workload, ok := s.balancer.nodeWorkload.Get(1)
-	s.True(ok)
-	s.Equal(int64(2), workload.Load())
-	workload, ok = s.balancer.nodeWorkload.Get(1)
-	s.True(ok)
-	s.Equal(int64(2), workload.Load())
+	// Test case 1: Empty availableNodes
+	_, err1 := balancer.SelectNode(context.Background(), []int64{}, 0)
+	if err1 != merr.ErrNodeNotAvailable {
+		t.Errorf("Expected ErrNodeNotAvailable, got %v", err1)
+	}

-	s.balancer.SelectNode(context.TODO(), availableNodes, 3)
-	s.balancer.SelectNode(context.TODO(), availableNodes, 1)
-	s.balancer.SelectNode(context.TODO(), availableNodes, 1)
-	s.balancer.SelectNode(context.TODO(), availableNodes, 1)
+	// Test case 2: Non-empty availableNodes
+	availableNodes := []int64{1, 2, 3}
+	selectedNode2, err2 := balancer.SelectNode(context.Background(), availableNodes, 0)
+	if err2 != nil {
+		t.Errorf("Expected no error, got %v", err2)
+	}
+	if selectedNode2 < 1 || selectedNode2 > 3 {
+		t.Errorf("Expected a node in the range [1, 3], got %d", selectedNode2)
+	}

-	workload, ok = s.balancer.nodeWorkload.Get(1)
-	s.True(ok)
-	s.Equal(int64(5), workload.Load())
-	workload, ok = s.balancer.nodeWorkload.Get(1)
-	s.True(ok)
-	s.Equal(int64(5), workload.Load())
-}
-
-func (s *RoundRobinBalancerSuite) TestNoAvailableNode() {
-	availableNodes := []int64{}
-	_, err := s.balancer.SelectNode(context.TODO(), availableNodes, 1)
-	s.Error(err)
-}
-
-func (s *RoundRobinBalancerSuite) TestCancelWorkload() {
-	availableNodes := []int64{101}
-	_, err := s.balancer.SelectNode(context.TODO(), availableNodes, 5)
-	s.NoError(err)
-	workload, ok := s.balancer.nodeWorkload.Get(101)
-	s.True(ok)
-	s.Equal(int64(5), workload.Load())
-	s.balancer.CancelWorkload(101, 5)
-	s.Equal(int64(0), workload.Load())
+	// Test case 3: Boundary case
+	availableNodes = []int64{1}
+	selectedNode3, err3 := balancer.SelectNode(context.Background(), availableNodes, 0)
+	if err3 != nil {
+		t.Errorf("Expected no error, got %v", err3)
+	}
+	if selectedNode3 != 1 {
+		t.Errorf("Expected 1, got %d", selectedNode3)
+	}
 }

 func TestRoundRobinBalancerSuite(t *testing.T) {