feat: Implement partial result support on node down (#42009)

issue: https://github.com/milvus-io/milvus/issues/41690
This commit implements partial search result functionality when query
nodes go down, improving system availability during node failures. The
changes include:

- Enhanced load balancing in proxy (lb_policy.go) to handle node
failures with retry support
- Added partial search result capability in querynode delegator and
distribution logic
- Implemented tests for various partial result scenarios when nodes go
down
- Added metrics to track partial search results in querynode_metrics.go
- Updated parameter configuration to support a required data ratio for
partial results
- Replaced the old partial_search_test.go with a more comprehensive
partial_result_on_node_down_test.go
- Updated proto definitions and improved retry logic

These changes improve query resilience by returning partial results to
users when some query nodes are unavailable, so queries no longer fail
outright while a portion of the data remains accessible.
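As a rough illustration of the gating rule behind these changes (the
function and names below are illustrative, not the actual Milvus API):
a shard answers with partial results when the fraction of data that is
still reachable meets the configured minimum ratio, and the default
ratio of 1 keeps the previous all-or-nothing behavior.

package main

import "fmt"

// allowPartialResult sketches the decision introduced by this commit:
// requiredRatio mirrors partialResultRequiredDataRatio (1 disables partial results).
func allowPartialResult(reachableSegments, totalSegments int, requiredRatio float64) bool {
	if totalSegments == 0 {
		return true
	}
	return float64(reachableSegments)/float64(totalSegments) >= requiredRatio
}

func main() {
	// With a required ratio of 0.8, losing 1 of 10 segments still returns results;
	// with the default ratio of 1, any missing segment fails the query.
	fmt.Println(allowPartialResult(9, 10, 0.8)) // true
	fmt.Println(allowPartialResult(9, 10, 1.0)) // false
}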

---------

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
wei liu 2025-05-28 00:12:28 +08:00 committed by GitHub
parent 57b58ad778
commit 54619eaa2c
37 changed files with 3447 additions and 2625 deletions

View File

@ -329,6 +329,7 @@ proxy:
slowQuerySpanInSeconds: 5 # query whose executed time exceeds the `slowQuerySpanInSeconds` can be considered slow, in seconds.
queryNodePooling:
size: 10 # the size for shardleader(querynode) client pool
partialResultRequiredDataRatio: 1 # minimum ratio of data that must be available to return partial results; the default of 1 disables partial results, any lower value is used as the required minimum data ratio
http:
enabled: true # Whether to enable the http server
debug_mode: false # Whether to enable http server debug mode

View File

@ -17,6 +17,7 @@ package proxy
import (
"context"
"strings"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
@ -119,22 +120,70 @@ func (lb *LBPolicyImpl) GetShardLeaders(ctx context.Context, dbName string, coll
}
// try to select the best node from the available nodes
func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (nodeInfo, error) {
filterDelegator := func(nodes []nodeInfo) map[int64]nodeInfo {
ret := make(map[int64]nodeInfo)
func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload *ChannelWorkload, excludeNodes typeutil.UniqueSet) (nodeInfo, error) {
log := log.Ctx(ctx)
// Select node using specified nodes
trySelectNode := func(nodes []nodeInfo) (nodeInfo, error) {
candidateNodes := make(map[int64]nodeInfo)
serviceableNodes := make(map[int64]nodeInfo)
// Filter nodes based on excludeNodes
for _, node := range nodes {
if !excludeNodes.Contain(node.nodeID) {
ret[node.nodeID] = node
if node.serviceable {
serviceableNodes[node.nodeID] = node
}
candidateNodes[node.nodeID] = node
}
}
return ret
var err error
defer func() {
if err != nil {
candidatesInStr := lo.Map(nodes, func(node nodeInfo, _ int) string {
return node.String()
})
serviceableNodesInStr := lo.Map(lo.Values(serviceableNodes), func(node nodeInfo, _ int) string {
return node.String()
})
log.Warn("failed to select shard",
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
zap.Int64s("excluded", excludeNodes.Collect()),
zap.String("candidates", strings.Join(candidatesInStr, ", ")),
zap.String("serviceableNodes", strings.Join(serviceableNodesInStr, ", ")),
zap.Error(err))
}
}()
if len(candidateNodes) == 0 {
err = merr.WrapErrChannelNotAvailable(workload.channel)
return nodeInfo{}, err
}
balancer.RegisterNodeInfo(lo.Values(candidateNodes))
// prefer serviceable nodes
var targetNodeID int64
if len(serviceableNodes) > 0 {
targetNodeID, err = balancer.SelectNode(ctx, lo.Keys(serviceableNodes), workload.nq)
} else {
targetNodeID, err = balancer.SelectNode(ctx, lo.Keys(candidateNodes), workload.nq)
}
if err != nil {
return nodeInfo{}, err
}
if _, ok := candidateNodes[targetNodeID]; !ok {
err = merr.WrapErrNodeNotAvailable(targetNodeID)
return nodeInfo{}, err
}
return candidateNodes[targetNodeID], nil
}
availableNodes := filterDelegator(workload.shardLeaders)
balancer.RegisterNodeInfo(lo.Values(availableNodes))
targetNode, err := balancer.SelectNode(ctx, lo.Keys(availableNodes), workload.nq)
// First attempt with current shard leaders
targetNode, err := trySelectNode(workload.shardLeaders)
// If failed, refresh cache and retry
if err != nil {
log := log.Ctx(ctx)
globalMetaCache.DeprecateShardCache(workload.db, workload.collectionName)
shardLeaders, err := lb.GetShardLeaders(ctx, workload.db, workload.collectionName, workload.collectionID, false)
if err != nil {
@ -145,51 +194,41 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor
return nodeInfo{}, err
}
availableNodes = filterDelegator(shardLeaders[workload.channel])
if len(availableNodes) == 0 {
log.Warn("no available shard delegator found",
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
zap.Int64s("availableNodes", lo.Keys(availableNodes)),
zap.Int64s("excluded", excludeNodes.Collect()))
return nodeInfo{}, merr.WrapErrChannelNotAvailable("no available shard delegator found")
}
balancer.RegisterNodeInfo(lo.Values(availableNodes))
targetNode, err = balancer.SelectNode(ctx, lo.Keys(availableNodes), workload.nq)
workload.shardLeaders = shardLeaders[workload.channel]
// Second attempt with fresh shard leaders
targetNode, err = trySelectNode(workload.shardLeaders)
if err != nil {
log.Warn("failed to select shard",
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
zap.Int64s("availableNodes", lo.Keys(availableNodes)),
zap.Int64s("excluded", excludeNodes.Collect()),
zap.Error(err))
return nodeInfo{}, err
}
}
return availableNodes[targetNode], nil
return targetNode, nil
}
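A minimal sketch of the selection preference applied by the trySelectNode
closure above, using invented helper and type names and omitting the
balancer, cache refresh, and logging: excluded nodes are filtered out,
serviceable delegators are preferred, and unserviceable ones are only a
fallback, so partial results stay possible when no replica is fully
serviceable.

package main

import "fmt"

type candidate struct {
	nodeID      int64
	serviceable bool
}

// splitCandidates drops excluded nodes and separates serviceable delegators
// from the remaining candidates.
func splitCandidates(nodes []candidate, excluded map[int64]bool) (serviceable, all []int64) {
	for _, n := range nodes {
		if excluded[n.nodeID] {
			continue
		}
		all = append(all, n.nodeID)
		if n.serviceable {
			serviceable = append(serviceable, n.nodeID)
		}
	}
	return serviceable, all
}

func main() {
	nodes := []candidate{{1, true}, {2, false}, {3, true}}
	serviceable, all := splitCandidates(nodes, map[int64]bool{1: true})
	// The balancer picks from the serviceable set when it is non-empty and
	// only falls back to the full candidate set otherwise.
	fmt.Println(serviceable, all) // [3] [2 3]
}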
// ExecuteWithRetry will choose a qn to execute the workload, and retry if failed, until reach the max retryTimes.
func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWorkload) error {
excludeNodes := typeutil.NewUniqueSet()
var lastErr error
err := retry.Do(ctx, func() error {
excludeNodes := typeutil.NewUniqueSet()
tryExecute := func() (bool, error) {
// if keeping retry after all nodes are excluded, try to clean excludeNodes
if excludeNodes.Len() == len(workload.shardLeaders) {
excludeNodes.Clear()
}
balancer := lb.getBalancer()
targetNode, err := lb.selectNode(ctx, balancer, workload, excludeNodes)
targetNode, err := lb.selectNode(ctx, balancer, &workload, excludeNodes)
if err != nil {
log.Warn("failed to select node for shard",
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
zap.Int64("nodeID", targetNode.nodeID),
zap.Int64s("excluded", excludeNodes.Collect()),
zap.Error(err),
)
if lastErr != nil {
return lastErr
return true, lastErr
}
return err
return true, err
}
// cancel work load which assign to the target node
defer balancer.CancelWorkload(targetNode.nodeID, workload.nq)
@ -204,7 +243,7 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
excludeNodes.Insert(targetNode.nodeID)
lastErr = errors.Wrapf(err, "failed to get delegator %d for channel %s", targetNode.nodeID, workload.channel)
return lastErr
return true, lastErr
}
err = workload.exec(ctx, targetNode.nodeID, client, workload.channel)
@ -216,11 +255,19 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
zap.Error(err))
excludeNodes.Insert(targetNode.nodeID)
lastErr = errors.Wrapf(err, "failed to search/query delegator %d for channel %s", targetNode.nodeID, workload.channel)
return lastErr
return true, lastErr
}
return nil
}, retry.Attempts(workload.retryTimes))
return true, nil
}
// if failed, try to execute with partial result
err := retry.Handle(ctx, tryExecute, retry.Attempts(workload.retryTimes))
if err != nil {
log.Ctx(ctx).Warn("failed to execute with partial result",
zap.String("channel", workload.channel),
zap.Error(err))
}
return err
}
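The retry flow above can be summarized with the following simplified
sketch (a hypothetical helper, not the real retry.Handle utility): each
failed delegator is excluded from the next attempt, the last error is
preserved for reporting, and the exclusion set is cleared once every
shard leader has been tried so recovered nodes can be revisited.

package main

import (
	"errors"
	"fmt"
)

// executeWithRetry is an illustrative stand-in for the proxy's retry loop:
// it keeps excluding failed nodes until one attempt succeeds or attempts run out.
func executeWithRetry(leaders []int64, attempts uint, try func(nodeID int64) error) error {
	excluded := make(map[int64]bool)
	var lastErr error
	for i := uint(0); i < attempts; i++ {
		// if all shard leaders have been excluded, start over with a clean set
		if len(excluded) == len(leaders) {
			excluded = make(map[int64]bool)
		}
		nodeID, ok := pickUnexcluded(leaders, excluded)
		if !ok {
			continue
		}
		if err := try(nodeID); err != nil {
			excluded[nodeID] = true
			lastErr = err
			continue
		}
		return nil
	}
	return lastErr
}

func pickUnexcluded(leaders []int64, excluded map[int64]bool) (int64, bool) {
	for _, id := range leaders {
		if !excluded[id] {
			return id, true
		}
	}
	return 0, false
}

func main() {
	err := executeWithRetry([]int64{1, 2}, 3, func(nodeID int64) error {
		if nodeID == 1 {
			return errors.New("node 1 is down")
		}
		return nil // node 2 answers on the second attempt
	})
	fmt.Println(err) // <nil>
}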
@ -233,8 +280,14 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad
return err
}
// let every request retry at least twice, so it can retry again after the shard leader cache is updated
wg, ctx := errgroup.WithContext(ctx)
totalChannels := len(dml2leaders)
if totalChannels == 0 {
log.Ctx(ctx).Info("no shard leaders found", zap.Int64("collectionID", workload.collectionID))
return merr.WrapErrCollectionNotLoaded(workload.collectionID)
}
wg, _ := errgroup.WithContext(ctx)
// Launch a goroutine for each channel
for k, v := range dml2leaders {
channel := k
nodes := v

View File

@ -69,8 +69,9 @@ func (s *LBPolicySuite) SetupTest() {
for i := 1; i <= 5; i++ {
s.nodeIDs = append(s.nodeIDs, int64(i))
s.nodes = append(s.nodes, nodeInfo{
nodeID: int64(i),
address: "localhost",
nodeID: int64(i),
address: "localhost",
serviceable: true,
})
}
s.channels = []string{"channel1", "channel2"}
@ -84,11 +85,13 @@ func (s *LBPolicySuite) SetupTest() {
ChannelName: s.channels[0],
NodeIds: s.nodeIDs,
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002", "localhost:9003", "localhost:9004"},
Serviceable: []bool{true, true, true, true, true},
},
{
ChannelName: s.channels[1],
NodeIds: s.nodeIDs,
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002", "localhost:9003", "localhost:9004"},
Serviceable: []bool{true, true, true, true, true},
},
},
}, nil
@ -175,7 +178,7 @@ func (s *LBPolicySuite) TestSelectNode() {
ctx := context.Background()
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(5, nil)
targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
@ -191,12 +194,12 @@ func (s *LBPolicySuite) TestSelectNode() {
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, errors.New("fake err")).Times(1)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(3, nil)
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: []nodeInfo{},
shardLeaders: s.nodes,
nq: 1,
}, typeutil.NewUniqueSet())
s.NoError(err)
@ -206,7 +209,7 @@ func (s *LBPolicySuite) TestSelectNode() {
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
@ -220,7 +223,7 @@ func (s *LBPolicySuite) TestSelectNode() {
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
@ -237,7 +240,7 @@ func (s *LBPolicySuite) TestSelectNode() {
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
return nil, merr.ErrServiceUnavailable
}
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
@ -291,7 +294,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
// test get client failed, and retry failed, expected success
s.mgr.ExpectedCalls = nil
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1)
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(2)
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
@ -313,6 +316,10 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
s.mgr.ExpectedCalls = nil
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1)
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
return availableNodes[0], nil
})
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
@ -334,7 +341,9 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
return availableNodes[0], nil
})
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
counter := 0
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
@ -384,7 +393,9 @@ func (s *LBPolicySuite) TestExecute() {
// test all channel success
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
return availableNodes[0], nil
})
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err := s.lbPolicy.Execute(ctx, CollectionWorkLoad{
db: dbName,

View File

@ -970,7 +970,6 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
}
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
log.Info("no shard cache for collection, try to get shard leaders from QueryCoord")
}
info, err := m.getFullCollectionInfo(ctx, database, collectionName, collectionID)
@ -983,7 +982,8 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
commonpbutil.WithMsgType(commonpb.MsgType_GetShardLeaders),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
CollectionID: info.collID,
CollectionID: info.collID,
WithUnserviceableShards: true,
}
tr := timerecord.NewTimeRecorder("UpdateShardCache")
@ -1002,6 +1002,19 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
idx: atomic.NewInt64(0),
}
// convert shards map to string for logging
if log.Logger.Level() == zap.DebugLevel {
shardStr := make([]string, 0, len(shards))
for channel, nodes := range shards {
nodeStrs := make([]string, 0, len(nodes))
for _, node := range nodes {
nodeStrs = append(nodeStrs, node.String())
}
shardStr = append(shardStr, fmt.Sprintf("%s:[%s]", channel, strings.Join(nodeStrs, ", ")))
}
log.Debug("update shard leader cache", zap.String("newShardLeaders", strings.Join(shardStr, ", ")))
}
m.leaderMut.Lock()
if _, ok := m.collLeader[database]; !ok {
m.collLeader[database] = make(map[string]*shardLeaders)
@ -1028,7 +1041,7 @@ func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) m
qns := make([]nodeInfo, len(leaders.GetNodeIds()))
for j := range qns {
qns[j] = nodeInfo{leaders.GetNodeIds()[j], leaders.GetNodeAddrs()[j]}
qns[j] = nodeInfo{leaders.GetNodeIds()[j], leaders.GetNodeAddrs()[j], leaders.GetServiceable()[j]}
}
shard2QueryNodes[leaders.GetChannelName()] = qns

View File

@ -1395,6 +1395,7 @@ func TestMetaCache_GetShards(t *testing.T) {
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, true, true},
},
},
}, nil
@ -1455,6 +1456,7 @@ func TestMetaCache_ClearShards(t *testing.T) {
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, true, true},
},
},
}, nil
@ -1836,6 +1838,7 @@ func TestMetaCache_InvalidateShardLeaderCache(t *testing.T) {
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, true, true},
},
},
}, nil

View File

@ -1505,6 +1505,7 @@ func (coord *MixCoordMock) GetShardLeaders(ctx context.Context, in *querypb.GetS
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, true, true},
},
},
}, nil

View File

@ -20,12 +20,13 @@ import (
type queryNodeCreatorFunc func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error)
type nodeInfo struct {
nodeID UniqueID
address string
nodeID UniqueID
address string
serviceable bool
}
func (n nodeInfo) String() string {
return fmt.Sprintf("<NodeID: %d>", n.nodeID)
return fmt.Sprintf("<NodeID: %d, serviceable: %v, address: %s>", n.nodeID, n.serviceable, n.address)
}
var errClosed = errors.New("client is closed")

View File

@ -161,6 +161,7 @@ func TestDropIndexTask_PreExecute(t *testing.T) {
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, true, true},
},
},
}, nil)
@ -185,6 +186,7 @@ func TestDropIndexTask_PreExecute(t *testing.T) {
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, true, true},
},
},
}, nil)
@ -210,6 +212,7 @@ func TestDropIndexTask_PreExecute(t *testing.T) {
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, true, true},
},
},
}, nil)
@ -241,6 +244,7 @@ func getMockQueryCoord() *mocks.MockMixCoordClient {
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, true, true},
},
},
}, nil)

View File

@ -2994,6 +2994,7 @@ func TestSearchTask_ErrExecute(t *testing.T) {
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, true, true},
},
},
}, nil

View File

@ -65,6 +65,7 @@ func (s *StatisticTaskSuite) SetupTest() {
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, true, true},
},
},
}, nil

View File

@ -978,6 +978,7 @@ func Test_isCollectionIsLoaded(t *testing.T) {
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, true, true},
},
},
}, nil)
@ -1002,6 +1003,7 @@ func Test_isCollectionIsLoaded(t *testing.T) {
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, true, true},
},
},
}, nil)
@ -1026,6 +1028,7 @@ func Test_isCollectionIsLoaded(t *testing.T) {
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, true, true},
},
},
}, nil)
@ -1057,6 +1060,7 @@ func Test_isPartitionIsLoaded(t *testing.T) {
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, true, true},
},
},
}, nil)
@ -1082,6 +1086,7 @@ func Test_isPartitionIsLoaded(t *testing.T) {
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, true, true},
},
},
}, nil)
@ -1107,6 +1112,7 @@ func Test_isPartitionIsLoaded(t *testing.T) {
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, true, true},
},
},
}, nil)

View File

@ -31,6 +31,7 @@ import (
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
@ -264,6 +265,12 @@ func (ob *TargetObserver) check(ctx context.Context, collectionID int64) {
if ob.shouldUpdateNextTarget(ctx, collectionID) {
// update next target in collection level
ob.updateNextTarget(ctx, collectionID)
// sync next target to delegator if the current target does not exist, to support partial search
if !ob.targetMgr.IsCurrentTargetExist(ctx, collectionID, -1) {
newVersion := ob.targetMgr.GetCollectionTargetVersion(ctx, collectionID, meta.NextTarget)
ob.syncNextTargetToDelegator(ctx, collectionID, ob.distMgr.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(collectionID)), newVersion)
}
}
}
@ -412,14 +419,18 @@ func (ob *TargetObserver) shouldUpdateCurrentTarget(ctx context.Context, collect
collReadyDelegatorList = append(collReadyDelegatorList, chReadyDelegatorList...)
}
return ob.syncNextTargetToDelegator(ctx, collectionID, collReadyDelegatorList, newVersion)
}
// sync next target info to delegator as readable snapshot
// 1. if next target is changed before delegator becomes serviceable, we need to sync the new next target to delegator to support partial search
// 2. if next target is ready to read, we need to sync the next target to delegator to support full search
func (ob *TargetObserver) syncNextTargetToDelegator(ctx context.Context, collectionID int64, collReadyDelegatorList []*meta.DmChannel, newVersion int64) bool {
var partitions []int64
var indexInfo []*indexpb.IndexInfo
var err error
for _, d := range collReadyDelegatorList {
updateVersionAction := ob.checkNeedUpdateTargetVersion(ctx, d.View, newVersion)
if updateVersionAction == nil {
continue
}
updateVersionAction := ob.genSyncAction(ctx, d.View, newVersion)
replica := ob.meta.ReplicaManager.GetByCollectionAndNode(ctx, collectionID, d.Node)
if replica == nil {
log.Warn("replica not found", zap.Int64("nodeID", d.Node), zap.Int64("collectionID", collectionID))
@ -441,19 +452,16 @@ func (ob *TargetObserver) shouldUpdateCurrentTarget(ctx context.Context, collect
}
}
if !ob.sync(ctx, replica, d.View, []*querypb.SyncAction{updateVersionAction}, partitions, indexInfo) {
if !ob.syncToDelegator(ctx, replica, d.View, updateVersionAction, partitions, indexInfo) {
return false
}
}
return true
}
func (ob *TargetObserver) sync(ctx context.Context, replica *meta.Replica, LeaderView *meta.LeaderView, diffs []*querypb.SyncAction,
func (ob *TargetObserver) syncToDelegator(ctx context.Context, replica *meta.Replica, LeaderView *meta.LeaderView, action *querypb.SyncAction,
partitions []int64, indexInfo []*indexpb.IndexInfo,
) bool {
if len(diffs) == 0 {
return true
}
replicaID := replica.GetID()
log := log.With(
@ -469,7 +477,7 @@ func (ob *TargetObserver) sync(ctx context.Context, replica *meta.Replica, Leade
CollectionID: LeaderView.CollectionID,
ReplicaID: replicaID,
Channel: LeaderView.Channel,
Actions: diffs,
Actions: []*querypb.SyncAction{action},
LoadMeta: &querypb.LoadMetaInfo{
LoadType: ob.meta.GetLoadType(ctx, LeaderView.CollectionID),
CollectionID: LeaderView.CollectionID,
@ -496,31 +504,34 @@ func (ob *TargetObserver) sync(ctx context.Context, replica *meta.Replica, Leade
return true
}
func (ob *TargetObserver) checkNeedUpdateTargetVersion(ctx context.Context, leaderView *meta.LeaderView, targetVersion int64) *querypb.SyncAction {
log.Ctx(ctx).WithRateGroup("qcv2.LeaderObserver", 1, 60)
if targetVersion <= leaderView.TargetVersion {
return nil
}
log.RatedInfo(10, "Update readable segment version",
zap.Int64("collectionID", leaderView.CollectionID),
zap.String("channelName", leaderView.Channel),
zap.Int64("nodeID", leaderView.ID),
zap.Int64("oldVersion", leaderView.TargetVersion),
zap.Int64("newVersion", targetVersion),
)
// sync next target info to delegator
// 1. if next target is changed before delegator becomes serviceable, we need to sync the new next target to delegator to support partial search
// 2. if next target is ready to read, we need to sync the next target to delegator to support full search
func (ob *TargetObserver) genSyncAction(ctx context.Context, leaderView *meta.LeaderView, targetVersion int64) *querypb.SyncAction {
log.Ctx(ctx).WithRateGroup("qcv2.LeaderObserver", 1, 60).
RatedInfo(10, "Update readable segment version",
zap.Int64("collectionID", leaderView.CollectionID),
zap.String("channelName", leaderView.Channel),
zap.Int64("nodeID", leaderView.ID),
zap.Int64("oldVersion", leaderView.TargetVersion),
zap.Int64("newVersion", targetVersion),
)
sealedSegments := ob.targetMgr.GetSealedSegmentsByChannel(ctx, leaderView.CollectionID, leaderView.Channel, meta.NextTarget)
growingSegments := ob.targetMgr.GetGrowingSegmentsByChannel(ctx, leaderView.CollectionID, leaderView.Channel, meta.NextTarget)
droppedSegments := ob.targetMgr.GetDroppedSegmentsByChannel(ctx, leaderView.CollectionID, leaderView.Channel, meta.NextTarget)
channel := ob.targetMgr.GetDmChannel(ctx, leaderView.CollectionID, leaderView.Channel, meta.NextTargetFirst)
sealedSegmentRowCount := lo.MapValues(sealedSegments, func(segment *datapb.SegmentInfo, _ int64) int64 {
return segment.GetNumOfRows()
})
action := &querypb.SyncAction{
Type: querypb.SyncType_UpdateVersion,
GrowingInTarget: growingSegments.Collect(),
SealedInTarget: lo.Keys(sealedSegments),
DroppedInTarget: droppedSegments,
TargetVersion: targetVersion,
Type: querypb.SyncType_UpdateVersion,
GrowingInTarget: growingSegments.Collect(),
SealedInTarget: lo.Keys(sealedSegmentRowCount),
DroppedInTarget: droppedSegments,
TargetVersion: targetVersion,
SealedSegmentRowCount: sealedSegmentRowCount,
}
if channel != nil {

View File

@ -263,7 +263,7 @@ func (suite *TargetObserverSuite) TestTriggerUpdateTarget() {
}, 7*time.Second, 1*time.Second)
ch1View := suite.distMgr.ChannelDistManager.GetByFilter(meta.WithChannelName2Channel("channel-1"))[0].View
action := suite.observer.checkNeedUpdateTargetVersion(ctx, ch1View, 100)
action := suite.observer.genSyncAction(ctx, ch1View, 100)
suite.Equal(action.GetDeleteCP().Timestamp, uint64(200))
}

View File

@ -902,7 +902,7 @@ func (s *Server) GetShardLeaders(ctx context.Context, req *querypb.GetShardLeade
}, nil
}
leaders, err := utils.GetShardLeaders(ctx, s.meta, s.targetMgr, s.dist, s.nodeMgr, req.GetCollectionID())
leaders, err := utils.GetShardLeaders(ctx, s.meta, s.targetMgr, s.dist, s.nodeMgr, req.GetCollectionID(), req.GetWithUnserviceableShards())
return &querypb.GetShardLeadersResponse{
Status: merr.Status(err),
Shards: leaders,

View File

@ -35,6 +35,7 @@ import (
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
@ -384,6 +385,7 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error {
return merr.WrapErrServiceInternal(fmt.Sprintf("failed to get partitions for collection=%d", task.CollectionID()))
}
version := ex.targetMgr.GetCollectionTargetVersion(ctx, task.CollectionID(), meta.NextTargetFirst)
req := packSubChannelRequest(
task,
action,
@ -392,6 +394,7 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error {
dmChannel,
indexInfo,
partitions,
version,
)
err = fillSubChannelRequest(ctx, req, ex.broker, ex.shouldIncludeFlushedSegmentInfo(action.Node()))
if err != nil {
@ -400,6 +403,12 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error {
return err
}
sealedSegments := ex.targetMgr.GetSealedSegmentsByChannel(ctx, dmChannel.CollectionID, dmChannel.ChannelName, meta.NextTarget)
sealedSegmentRowCount := lo.MapValues(sealedSegments, func(segment *datapb.SegmentInfo, _ int64) int64 {
return segment.GetNumOfRows()
})
req.SealedSegmentRowCount = sealedSegmentRowCount
ts := dmChannel.GetSeekPosition().GetTimestamp()
log.Info("subscribe channel...",
zap.Uint64("checkpoint", ts),

View File

@ -208,6 +208,7 @@ func packSubChannelRequest(
channel *meta.DmChannel,
indexInfo []*indexpb.IndexInfo,
partitions []int64,
targetVersion int64,
) *querypb.WatchDmChannelsRequest {
return &querypb.WatchDmChannelsRequest{
Base: commonpbutil.NewMsgBase(
@ -223,6 +224,7 @@ func packSubChannelRequest(
ReplicaID: task.ReplicaID(),
Version: time.Now().UnixNano(),
IndexInfoList: indexInfo,
TargetVersion: targetVersion,
}
}
@ -253,6 +255,7 @@ func fillSubChannelRequest(
req.SegmentInfos = lo.SliceToMap(segmentInfos, func(info *datapb.SegmentInfo) (int64, *datapb.SegmentInfo) {
return info.GetID(), info
})
return nil
}

View File

@ -100,8 +100,14 @@ func checkLoadStatus(ctx context.Context, m *meta.Meta, collectionID int64) erro
return nil
}
func GetShardLeadersWithChannels(ctx context.Context, m *meta.Meta, targetMgr meta.TargetManagerInterface, dist *meta.DistributionManager,
nodeMgr *session.NodeManager, collectionID int64, channels map[string]*meta.DmChannel,
func GetShardLeadersWithChannels(
ctx context.Context,
m *meta.Meta,
dist *meta.DistributionManager,
nodeMgr *session.NodeManager,
collectionID int64,
channels map[string]*meta.DmChannel,
withUnserviceableShards bool,
) ([]*querypb.ShardLeadersList, error) {
ret := make([]*querypb.ShardLeadersList, 0)
@ -111,9 +117,10 @@ func GetShardLeadersWithChannels(ctx context.Context, m *meta.Meta, targetMgr me
ids := make([]int64, 0, len(replicas))
addrs := make([]string, 0, len(replicas))
serviceable := make([]bool, 0, len(replicas))
for _, replica := range replicas {
leader := dist.ChannelDistManager.GetShardLeader(channel.GetChannelName(), replica)
if leader == nil || !leader.IsServiceable() {
if leader == nil || (!withUnserviceableShards && !leader.IsServiceable()) {
log.WithRateGroup("util.GetShardLeaders", 1, 60).
Warn("leader is not available in replica", zap.String("channel", channel.GetChannelName()), zap.Int64("replicaID", replica.GetID()))
continue
@ -122,11 +129,11 @@ func GetShardLeadersWithChannels(ctx context.Context, m *meta.Meta, targetMgr me
if info != nil {
ids = append(ids, info.ID())
addrs = append(addrs, info.Addr())
serviceable = append(serviceable, leader.IsServiceable())
}
}
// to avoid node down during GetShardLeaders
if len(ids) == 0 {
if len(ids) == 0 && !withUnserviceableShards {
err := merr.WrapErrChannelNotAvailable(channel.GetChannelName())
msg := fmt.Sprintf("channel %s is not available in any replica", channel.GetChannelName())
log.Warn(msg, zap.Error(err))
@ -137,13 +144,22 @@ func GetShardLeadersWithChannels(ctx context.Context, m *meta.Meta, targetMgr me
ChannelName: channel.GetChannelName(),
NodeIds: ids,
NodeAddrs: addrs,
Serviceable: serviceable,
})
}
return ret, nil
}
func GetShardLeaders(ctx context.Context, m *meta.Meta, targetMgr meta.TargetManagerInterface, dist *meta.DistributionManager, nodeMgr *session.NodeManager, collectionID int64) ([]*querypb.ShardLeadersList, error) {
func GetShardLeaders(ctx context.Context,
m *meta.Meta,
targetMgr meta.TargetManagerInterface,
dist *meta.DistributionManager,
nodeMgr *session.NodeManager,
collectionID int64,
withUnserviceableShards bool,
) ([]*querypb.ShardLeadersList, error) {
// skip check load status if withUnserviceableShards is true
if err := checkLoadStatus(ctx, m, collectionID); err != nil {
return nil, err
}
@ -155,7 +171,7 @@ func GetShardLeaders(ctx context.Context, m *meta.Meta, targetMgr meta.TargetMan
log.Ctx(ctx).Warn("failed to get channels", zap.Error(err))
return nil, err
}
return GetShardLeadersWithChannels(ctx, m, targetMgr, dist, nodeMgr, collectionID, channels)
return GetShardLeadersWithChannels(ctx, m, dist, nodeMgr, collectionID, channels, withUnserviceableShards)
}
// CheckCollectionsQueryable check all channels are watched and all segments are loaded for this collection
@ -194,7 +210,7 @@ func checkCollectionQueryable(ctx context.Context, m *meta.Meta, targetMgr meta.
return err
}
shardList, err := GetShardLeadersWithChannels(ctx, m, targetMgr, dist, nodeMgr, collectionID, channels)
shardList, err := GetShardLeadersWithChannels(ctx, m, dist, nodeMgr, collectionID, channels, false)
if err != nil {
return err
}

View File

@ -30,12 +30,12 @@ import (
"go.opentelemetry.io/otel"
"go.uber.org/atomic"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/querynodev2/cluster"
@ -88,8 +88,8 @@ type ShardDelegator interface {
LoadL0(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64) error
LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) error
ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest, force bool) error
SyncTargetVersion(newVersion int64, partitions []int64, growingInTarget []int64, sealedInTarget []int64, droppedInTarget []int64, checkpoint *msgpb.MsgPosition, deleteSeekPos *msgpb.MsgPosition)
GetQueryView() *channelQueryView
SyncTargetVersion(action *querypb.SyncAction, partitions []int64)
GetChannelQueryView() *channelQueryView
GetDeleteBufferSize() (entryNum int64, memorySize int64)
// manage exclude segments
@ -369,21 +369,32 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
req.Req.GetIsIterator(),
)
partialResultRequiredDataRatio := paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.GetAsFloat()
// wait tsafe
waitTr := timerecord.NewTimeRecorder("wait tSafe")
tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
if err != nil {
log.Warn("delegator search failed to wait tsafe", zap.Error(err))
return nil, err
}
if req.GetReq().GetMvccTimestamp() == 0 {
req.Req.MvccTimestamp = tSafe
var tSafe uint64
var err error
if partialResultRequiredDataRatio >= 1.0 {
tSafe, err = sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
if err != nil {
log.Warn("delegator search failed to wait tsafe", zap.Error(err))
return nil, err
}
if req.GetReq().GetMvccTimestamp() == 0 {
req.Req.MvccTimestamp = tSafe
}
} else {
tSafe = sd.GetTSafe()
if req.GetReq().GetMvccTimestamp() == 0 {
req.Req.MvccTimestamp = tSafe
}
}
metrics.QueryNodeSQLatencyWaitTSafe.WithLabelValues(
fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel).
Observe(float64(waitTr.ElapseSpan().Milliseconds()))
sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...)
sealed, growing, version, err := sd.distribution.PinReadableSegments(partialResultRequiredDataRatio, req.GetReq().GetPartitionIDs()...)
if err != nil {
log.Warn("delegator failed to search, current distribution is not serviceable", zap.Error(err))
return nil, err
@ -500,7 +511,7 @@ func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryReq
fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel).
Observe(float64(waitTr.ElapseSpan().Milliseconds()))
sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...)
sealed, growing, version, err := sd.distribution.PinReadableSegments(float64(1.0), req.GetReq().GetPartitionIDs()...)
if err != nil {
log.Warn("delegator failed to query, current distribution is not serviceable", zap.Error(err))
return err
@ -562,21 +573,31 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest)
req.Req.GetIsIterator(),
)
partialResultRequiredDataRatio := paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.GetAsFloat()
// wait tsafe
waitTr := timerecord.NewTimeRecorder("wait tSafe")
tSafe, err := sd.waitTSafe(ctx, req.Req.GetGuaranteeTimestamp())
if err != nil {
log.Warn("delegator query failed to wait tsafe", zap.Error(err))
return nil, err
}
if req.GetReq().GetMvccTimestamp() == 0 {
req.Req.MvccTimestamp = tSafe
var tSafe uint64
var err error
if partialResultRequiredDataRatio >= 1.0 {
tSafe, err = sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
if err != nil {
log.Warn("delegator search failed to wait tsafe", zap.Error(err))
return nil, err
}
if req.GetReq().GetMvccTimestamp() == 0 {
req.Req.MvccTimestamp = tSafe
}
} else {
if req.GetReq().GetMvccTimestamp() == 0 {
req.Req.MvccTimestamp = sd.GetTSafe()
}
}
metrics.QueryNodeSQLatencyWaitTSafe.WithLabelValues(
fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel).
Observe(float64(waitTr.ElapseSpan().Milliseconds()))
sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...)
sealed, growing, version, err := sd.distribution.PinReadableSegments(partialResultRequiredDataRatio, req.GetReq().GetPartitionIDs()...)
if err != nil {
log.Warn("delegator failed to query, current distribution is not serviceable", zap.Error(err))
return nil, err
@ -646,7 +667,7 @@ func (sd *shardDelegator) GetStatistics(ctx context.Context, req *querypb.GetSta
return nil, err
}
sealed, growing, version, err := sd.distribution.PinReadableSegments(req.Req.GetPartitionIDs()...)
sealed, growing, version, err := sd.distribution.PinReadableSegments(1.0, req.Req.GetPartitionIDs()...)
if err != nil {
log.Warn("delegator failed to GetStatistics, current distribution is not servicable")
return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not serviceable")
@ -708,13 +729,14 @@ func organizeSubTask[T any](ctx context.Context,
// update request
req := modify(req, scope, segmentIDs, workerID)
// for partial search, tolerate that some workers are offline
worker, err := sd.workerManager.GetWorker(ctx, workerID)
if err != nil {
log.Warn("failed to get worker",
log.Warn("failed to get worker for sub task",
zap.Int64("nodeID", workerID),
zap.Int64s("segments", segmentIDs),
zap.Error(err),
)
return fmt.Errorf("failed to get worker %d, %w", workerID, err)
}
result = append(result, subTask[T]{
@ -744,50 +766,110 @@ func executeSubTasks[T any, R interface {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var wg sync.WaitGroup
wg.Add(len(tasks))
var partialResultRequiredDataRatio float64
if taskType == "Query" || taskType == "Search" {
partialResultRequiredDataRatio = paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.GetAsFloat()
} else {
partialResultRequiredDataRatio = 1.0
}
resultCh := make(chan R, len(tasks))
errCh := make(chan error, 1)
wg, ctx := errgroup.WithContext(ctx)
type channelResult struct {
nodeID int64
segments []int64
result R
err error
}
// Buffered channel to collect results from all goroutines
resultCh := make(chan channelResult, len(tasks))
for _, task := range tasks {
go func(task subTask[T]) {
defer wg.Done()
result, err := execute(ctx, task.req, task.worker)
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
err = fmt.Errorf("worker(%d) query failed: %s", task.targetID, result.GetStatus().GetReason())
task := task // capture loop variable
wg.Go(func() error {
var result R
var err error
if task.targetID == -1 || task.worker == nil {
var segments []int64
if req, ok := any(task.req).(interface{ GetSegmentIDs() []int64 }); ok {
segments = req.GetSegmentIDs()
} else {
segments = []int64{}
}
err = fmt.Errorf("segments not loaded in any worker: %v", segments[:min(len(segments), 10)])
} else {
result, err = execute(ctx, task.req, task.worker)
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
err = fmt.Errorf("worker(%d) query failed: %s", task.targetID, result.GetStatus().GetReason())
}
}
if err != nil {
log.Warn("failed to execute sub task",
zap.String("taskType", taskType),
zap.Int64("nodeID", task.targetID),
zap.Error(err),
)
select {
case errCh <- err: // must be the first
default: // skip other errors
// check if partial result is disabled, if so, let all sub tasks fail fast
if partialResultRequiredDataRatio == 1 {
return err
}
cancel()
return
}
resultCh <- result
}(task)
taskResult := channelResult{
nodeID: task.targetID,
result: result,
err: err,
}
if req, ok := any(task.req).(interface{ GetSegmentIDs() []int64 }); ok {
taskResult.segments = req.GetSegmentIDs()
}
resultCh <- taskResult
return nil
})
}
wg.Wait()
close(resultCh)
select {
case err := <-errCh:
log.Warn("Delegator execute subTask failed",
// Wait for all tasks to complete
if err := wg.Wait(); err != nil {
log.Warn("some tasks failed to complete",
zap.String("taskType", taskType),
zap.Error(err),
)
return nil, err
default:
}
close(resultCh)
successSegmentList := []int64{}
failureSegmentList := []int64{}
var errors []error
// Collect results
results := make([]R, 0, len(tasks))
for item := range resultCh {
if item.err == nil {
successSegmentList = append(successSegmentList, item.segments...)
results = append(results, item.result)
} else {
failureSegmentList = append(failureSegmentList, item.segments...)
errors = append(errors, item.err)
}
}
results := make([]R, 0, len(tasks))
for result := range resultCh {
results = append(results, result)
accessDataRatio := 1.0
totalSegments := len(successSegmentList) + len(failureSegmentList)
if totalSegments > 0 {
accessDataRatio = float64(len(successSegmentList)) / float64(totalSegments)
if accessDataRatio < 1.0 {
log.Info("partial result executed successfully",
zap.String("taskType", taskType),
zap.Float64("successRatio", accessDataRatio),
zap.Float64("partialResultRequiredDataRatio", partialResultRequiredDataRatio),
zap.Int("totalSegments", totalSegments),
zap.Int64s("failureSegmentList", failureSegmentList),
)
}
}
if accessDataRatio < partialResultRequiredDataRatio {
return nil, merr.Combine(errors...)
}
return results, nil
}
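A simplified sketch of the collection pattern used above, with invented
types and without the worker abstraction: when partial results are
enabled, sub-task errors are tallied instead of cancelling the whole
request, and the caller checks the reachable-data ratio afterwards.

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

type subTaskResult struct {
	nodeID   int64
	segments []int64
	err      error
}

// runSubTasks fans out one goroutine per node and reports every outcome into
// a buffered channel so failures do not abort the other sub-tasks.
func runSubTasks(ctx context.Context, plan map[int64][]int64, run func(context.Context, int64) error) []subTaskResult {
	results := make(chan subTaskResult, len(plan))
	g, ctx := errgroup.WithContext(ctx)
	for nodeID, segments := range plan {
		nodeID, segments := nodeID, segments
		g.Go(func() error {
			results <- subTaskResult{nodeID: nodeID, segments: segments, err: run(ctx, nodeID)}
			return nil // errors are tallied by the caller, not propagated here
		})
	}
	_ = g.Wait()
	close(results)
	collected := make([]subTaskResult, 0, len(plan))
	for r := range results {
		collected = append(collected, r)
	}
	return collected
}

func main() {
	results := runSubTasks(context.Background(),
		map[int64][]int64{1: {100, 101}, 2: {102}},
		func(_ context.Context, nodeID int64) error {
			if nodeID == 2 {
				return fmt.Errorf("worker %d is offline", nodeID)
			}
			return nil
		})
	success, total := 0, 0
	for _, r := range results {
		total += len(r.segments)
		if r.err == nil {
			success += len(r.segments)
		}
	}
	// 2 of 3 segments answered; with a required ratio of 0.5 the partial results
	// are returned, with the default of 1 the combined errors are returned instead.
	fmt.Printf("reachable data ratio: %.2f\n", float64(success)/float64(total))
}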
@ -891,7 +973,7 @@ func (sd *shardDelegator) UpdateSchema(ctx context.Context, schema *schemapb.Col
log.Info("delegator received update schema event")
sealed, growing, version, err := sd.distribution.PinReadableSegments()
sealed, growing, version, err := sd.distribution.PinReadableSegments(1.0)
if err != nil {
log.Warn("delegator failed to query, current distribution is not serviceable", zap.Error(err))
return err
@ -1011,6 +1093,7 @@ func (sd *shardDelegator) loadPartitionStats(ctx context.Context, partStatsVersi
func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID UniqueID, channel string, version int64,
workerManager cluster.Manager, manager *segments.Manager, loader segments.Loader,
factory msgstream.Factory, startTs uint64, queryHook optimizers.QueryHook, chunkManager storage.ChunkManager,
queryView *channelQueryView,
) (ShardDelegator, error) {
log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID),
zap.Int64("replicaID", replicaID),
@ -1041,7 +1124,7 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
segmentManager: manager.Segment,
workerManager: workerManager,
lifetime: lifetime.NewLifetime(lifetime.Initializing),
distribution: NewDistribution(channel),
distribution: NewDistribution(channel, queryView),
deleteBuffer: deletebuffer.NewListDeleteBuffer[*deletebuffer.Item](startTs, sizePerBlock,
[]string{fmt.Sprint(paramtable.GetNodeID()), channel}),
pkOracle: pkoracle.NewPkOracle(),

View File

@ -959,70 +959,45 @@ func (sd *shardDelegator) ReleaseSegments(ctx context.Context, req *querypb.Rele
return nil
}
func (sd *shardDelegator) SyncTargetVersion(
newVersion int64,
partitions []int64,
growingInTarget []int64,
sealedInTarget []int64,
droppedInTarget []int64,
checkpoint *msgpb.MsgPosition,
deleteSeekPos *msgpb.MsgPosition,
) {
growings := sd.segmentManager.GetBy(
segments.WithType(segments.SegmentTypeGrowing),
segments.WithChannel(sd.vchannelName),
)
sealedSet := typeutil.NewUniqueSet(sealedInTarget...)
growingSet := typeutil.NewUniqueSet(growingInTarget...)
droppedSet := typeutil.NewUniqueSet(droppedInTarget...)
redundantGrowing := typeutil.NewUniqueSet()
for _, s := range growings {
if growingSet.Contain(s.ID()) {
continue
func (sd *shardDelegator) SyncTargetVersion(action *querypb.SyncAction, partitions []int64) {
sd.distribution.SyncTargetVersion(action, partitions)
// clean delete buffer after distribution becomes serviceable
if sd.distribution.queryView.Serviceable() {
checkpoint := action.GetCheckpoint()
deleteSeekPos := action.GetDeleteCP()
if deleteSeekPos == nil {
// for compatibility with 2.4, we use checkpoint as deleteCP when deleteCP is nil
deleteSeekPos = checkpoint
log.Info("use checkpoint as deleteCP",
zap.String("channelName", sd.vchannelName),
zap.Time("deleteSeekPos", tsoutil.PhysicalTime(action.GetCheckpoint().GetTimestamp())))
}
// sealed segment already exists, make growing segment redundant
if sealedSet.Contain(s.ID()) {
redundantGrowing.Insert(s.ID())
}
// sealed segment already dropped, make growing segment redundant
if droppedSet.Contain(s.ID()) {
redundantGrowing.Insert(s.ID())
start := time.Now()
sizeBeforeClean, _ := sd.deleteBuffer.Size()
l0NumBeforeClean := len(sd.deleteBuffer.ListL0())
sd.deleteBuffer.UnRegister(deleteSeekPos.GetTimestamp())
sizeAfterClean, _ := sd.deleteBuffer.Size()
l0NumAfterClean := len(sd.deleteBuffer.ListL0())
if sizeAfterClean < sizeBeforeClean || l0NumAfterClean < l0NumBeforeClean {
log.Info("clean delete buffer",
zap.String("channel", sd.vchannelName),
zap.Time("deleteSeekPos", tsoutil.PhysicalTime(deleteSeekPos.GetTimestamp())),
zap.Time("channelCP", tsoutil.PhysicalTime(checkpoint.GetTimestamp())),
zap.Int64("sizeBeforeClean", sizeBeforeClean),
zap.Int64("sizeAfterClean", sizeAfterClean),
zap.Int("l0NumBeforeClean", l0NumBeforeClean),
zap.Int("l0NumAfterClean", l0NumAfterClean),
zap.Duration("cost", time.Since(start)),
)
}
sd.RefreshLevel0DeletionStats()
}
redundantGrowingIDs := redundantGrowing.Collect()
if len(redundantGrowing) > 0 {
log.Warn("found redundant growing segments",
zap.Int64s("growingSegments", redundantGrowingIDs))
}
sd.distribution.SyncTargetVersion(newVersion, partitions, growingInTarget, sealedInTarget, redundantGrowingIDs)
start := time.Now()
sizeBeforeClean, _ := sd.deleteBuffer.Size()
l0NumBeforeClean := len(sd.deleteBuffer.ListL0())
sd.deleteBuffer.UnRegister(deleteSeekPos.GetTimestamp())
sizeAfterClean, _ := sd.deleteBuffer.Size()
l0NumAfterClean := len(sd.deleteBuffer.ListL0())
if sizeAfterClean < sizeBeforeClean || l0NumAfterClean < l0NumBeforeClean {
log.Info("clean delete buffer",
zap.String("channel", sd.vchannelName),
zap.Time("deleteSeekPos", tsoutil.PhysicalTime(deleteSeekPos.GetTimestamp())),
zap.Time("channelCP", tsoutil.PhysicalTime(checkpoint.GetTimestamp())),
zap.Int64("sizeBeforeClean", sizeBeforeClean),
zap.Int64("sizeAfterClean", sizeAfterClean),
zap.Int("l0NumBeforeClean", l0NumBeforeClean),
zap.Int("l0NumAfterClean", l0NumAfterClean),
zap.Duration("cost", time.Since(start)),
)
}
sd.RefreshLevel0DeletionStats()
}
func (sd *shardDelegator) GetQueryView() *channelQueryView {
return sd.distribution.GetQueryView()
func (sd *shardDelegator) GetChannelQueryView() *channelQueryView {
return sd.distribution.queryView
}
func (sd *shardDelegator) AddExcludedSegments(excludeInfo map[int64]uint64) {

View File

@ -194,7 +194,7 @@ func (s *DelegatorDataSuite) genCollectionWithFunction() {
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager)
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.NoError(err)
s.delegator = delegator.(*shardDelegator)
}
@ -216,7 +216,7 @@ func (s *DelegatorDataSuite) SetupTest() {
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager)
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.Require().NoError(err)
sd, ok := delegator.(*shardDelegator)
s.Require().True(ok)
@ -419,7 +419,16 @@ func (s *DelegatorDataSuite) TestProcessDelete() {
s.Require().NoError(err)
// sync target version, make delegator serviceable
s.delegator.SyncTargetVersion(time.Now().UnixNano(), []int64{500}, []int64{1001}, []int64{1000}, nil, &msgpb.MsgPosition{}, &msgpb.MsgPosition{})
s.delegator.SyncTargetVersion(&querypb.SyncAction{
TargetVersion: 2001,
GrowingInTarget: []int64{1001},
SealedSegmentRowCount: map[int64]int64{
1000: 100,
},
DroppedInTarget: []int64{},
Checkpoint: &msgpb.MsgPosition{},
DeleteCP: &msgpb.MsgPosition{},
}, []int64{500, 501})
s.delegator.ProcessDelete([]*DeleteData{
{
PartitionID: 500,
@ -799,7 +808,7 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, nil)
}, 10000, nil, nil, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.NoError(err)
growing0 := segments.NewMockSegment(s.T())
@ -1422,8 +1431,15 @@ func (s *DelegatorDataSuite) TestSyncTargetVersion() {
s.manager.Segment.Put(context.Background(), segments.SegmentTypeGrowing, ms)
}
s.delegator.SyncTargetVersion(int64(5), []int64{1}, []int64{1}, []int64{2}, []int64{3, 4}, &msgpb.MsgPosition{}, &msgpb.MsgPosition{})
s.Equal(int64(5), s.delegator.GetQueryView().GetVersion())
s.delegator.SyncTargetVersion(&querypb.SyncAction{
TargetVersion: 5,
GrowingInTarget: []int64{1},
SealedInTarget: []int64{2},
DroppedInTarget: []int64{3, 4},
Checkpoint: &msgpb.MsgPosition{},
DeleteCP: &msgpb.MsgPosition{},
}, []int64{500, 501})
s.Equal(int64(5), s.delegator.GetChannelQueryView().GetVersion())
}
func (s *DelegatorDataSuite) TestLevel0Deletions() {

View File

@ -163,7 +163,7 @@ func (s *DelegatorSuite) SetupTest() {
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager)
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.Require().NoError(err)
}
@ -202,7 +202,7 @@ func (s *DelegatorSuite) TestCreateDelegatorWithFunction() {
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager)
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.Error(err)
})
@ -245,7 +245,7 @@ func (s *DelegatorSuite) TestCreateDelegatorWithFunction() {
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager)
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.NoError(err)
})
}
@ -325,7 +325,19 @@ func (s *DelegatorSuite) initSegments() {
Version: 2001,
},
)
s.delegator.SyncTargetVersion(2001, []int64{500, 501}, []int64{1004}, []int64{1000, 1001, 1002, 1003}, []int64{}, &msgpb.MsgPosition{}, &msgpb.MsgPosition{})
s.delegator.SyncTargetVersion(&querypb.SyncAction{
TargetVersion: 2001,
GrowingInTarget: []int64{1004},
SealedSegmentRowCount: map[int64]int64{
1000: 100,
1001: 100,
1002: 100,
1003: 100,
},
DroppedInTarget: []int64{},
Checkpoint: &msgpb.MsgPosition{},
DeleteCP: &msgpb.MsgPosition{},
}, []int64{500, 501})
}
func (s *DelegatorSuite) TestSearch() {
@ -888,7 +900,8 @@ func (s *DelegatorSuite) TestQueryStream() {
Req: &internalpb.RetrieveRequest{Base: commonpbutil.NewMsgBase()},
DmlChannels: []string{s.vchannelName},
}, server)
s.True(errors.Is(err, mockErr))
s.Error(err)
s.ErrorContains(err, "segments not loaded in any worker")
})
s.Run("worker_return_error", func() {
@ -1251,7 +1264,7 @@ func (s *DelegatorSuite) TestUpdateSchema() {
s.Run("worker_manager_error", func() {
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).RunAndReturn(func(ctx context.Context, i int64) (cluster.Worker, error) {
return nil, merr.WrapErrServiceInternal("mocked")
}).Once()
})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

View File

@ -155,7 +155,7 @@ func (s *StreamingForwardSuite) SetupTest() {
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager)
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.Require().NoError(err)
sd, ok := delegator.(*shardDelegator)
@ -185,7 +185,13 @@ func (s *StreamingForwardSuite) TestBFStreamingForward() {
PartitionID: 1,
SegmentID: 102,
})
delegator.distribution.SyncTargetVersion(1, []int64{1}, []int64{100}, []int64{101, 102}, nil)
delegator.distribution.SyncTargetVersion(&querypb.SyncAction{
TargetVersion: 1,
GrowingInTarget: []int64{100},
SealedInTarget: []int64{101, 102},
DroppedInTarget: nil,
Checkpoint: nil,
}, []int64{1})
// Setup pk oracle
// empty bfs will not match
@ -238,7 +244,13 @@ func (s *StreamingForwardSuite) TestDirectStreamingForward() {
PartitionID: 1,
SegmentID: 102,
})
delegator.distribution.SyncTargetVersion(1, []int64{1}, []int64{100}, []int64{101, 102}, nil)
delegator.distribution.SyncTargetVersion(&querypb.SyncAction{
TargetVersion: 1,
GrowingInTarget: []int64{100},
SealedInTarget: []int64{101, 102},
DroppedInTarget: nil,
Checkpoint: nil,
}, []int64{1})
// Setup pk oracle
// empty bfs will not match
@ -386,7 +398,7 @@ func (s *GrowingMergeL0Suite) SetupTest() {
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager)
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.Require().NoError(err)
sd, ok := delegator.(*shardDelegator)

View File

@ -17,6 +17,7 @@
package delegator
import (
"fmt"
"sync"
"github.com/samber/lo"
@ -25,6 +26,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -56,12 +58,26 @@ func getClosedCh() chan struct{} {
}
// channelQueryView maintains the sealed segment list which should be used for search/query.
// a new delegator gets a fresh channelQueryView from WatchChannel and receives queryView updates from querycoord before it becomes serviceable;
// after the delegator becomes serviceable, the queryView is only updated by SyncTargetVersion
type channelQueryView struct {
sealedSegments []int64 // sealed segment list which should be used for search/query
partitions typeutil.UniqueSet // partitions list which sealed segments belong to
version int64 // version of current query view, same as targetVersion in qc
growingSegments typeutil.UniqueSet // growing segment list which should be used for search/query
sealedSegmentRowCount map[int64]int64 // sealed segment list which should be used for search/query, segmentID -> row count
partitions typeutil.UniqueSet // partitions list which sealed segments belong to
version int64 // version of current query view, same as targetVersion in qc
serviceable *atomic.Bool
loadedRatio *atomic.Float64 // loaded ratio of current query view, set serviceable to true if loadedRatio == 1.0
unloadedSealedSegments []SegmentEntry // sealed segments not loaded on any worker; exposed under the placeholder NodeID -1
}
func NewChannelQueryView(growings []int64, sealedSegmentRowCount map[int64]int64, partitions []int64, version int64) *channelQueryView {
return &channelQueryView{
growingSegments: typeutil.NewUniqueSet(growings...),
sealedSegmentRowCount: sealedSegmentRowCount,
partitions: typeutil.NewUniqueSet(partitions...),
version: version,
loadedRatio: atomic.NewFloat64(0),
}
}
func (q *channelQueryView) GetVersion() int64 {
@ -69,7 +85,11 @@ func (q *channelQueryView) GetVersion() int64 {
}
func (q *channelQueryView) Serviceable() bool {
return q.serviceable.Load()
return q.loadedRatio.Load() >= 1.0
}
func (q *channelQueryView) GetLoadedRatio() float64 {
return q.loadedRatio.Load()
}
// distribution is the struct to store segment distribution.
@ -107,22 +127,17 @@ type SegmentEntry struct {
Offline bool // if delegator failed to execute forwardDelete/Query/Search on segment, it will be offline
}
// NewDistribution creates a new distribution instance with all field initialized.
func NewDistribution(channelName string) *distribution {
func NewDistribution(channelName string, queryView *channelQueryView) *distribution {
dist := &distribution{
channelName: channelName,
growingSegments: make(map[UniqueID]SegmentEntry),
sealedSegments: make(map[UniqueID]SegmentEntry),
snapshots: typeutil.NewConcurrentMap[int64, *snapshot](),
current: atomic.NewPointer[snapshot](nil),
queryView: &channelQueryView{
serviceable: atomic.NewBool(false),
partitions: typeutil.NewSet[int64](),
version: initialTargetVersion,
},
queryView: queryView,
}
dist.genSnapshot()
dist.updateServiceable("NewDistribution")
return dist
}
@ -132,12 +147,14 @@ func (d *distribution) SetIDFOracle(idfOracle IDFOracle) {
d.idfOracle = idfOracle
}
func (d *distribution) PinReadableSegments(partitions ...int64) (sealed []SnapshotItem, growing []SegmentEntry, version int64, err error) {
// PinReadableSegments returns the segment distribution in the current query view
func (d *distribution) PinReadableSegments(requiredLoadRatio float64, partitions ...int64) (sealed []SnapshotItem, growing []SegmentEntry, version int64, err error) {
d.mut.RLock()
defer d.mut.RUnlock()
if !d.Serviceable() {
return nil, nil, -1, merr.WrapErrChannelNotAvailable("channel distribution is not serviceable")
if d.queryView.GetLoadedRatio() < requiredLoadRatio {
return nil, nil, -1, merr.WrapErrChannelNotAvailable(d.channelName,
fmt.Sprintf("channel distribution is not serviceable, required load ratio is %f, current load ratio is %f", requiredLoadRatio, d.queryView.GetLoadedRatio()))
}
current := d.current.Load()
@ -153,6 +170,15 @@ func (d *distribution) PinReadableSegments(partitions ...int64) (sealed []Snapsh
targetVersion := current.GetTargetVersion()
filterReadable := d.readableFilter(targetVersion)
sealed, growing = d.filterSegments(sealed, growing, filterReadable)
if len(d.queryView.unloadedSealedSegments) > 0 {
// append distribution of unloaded segment
sealed = append(sealed, SnapshotItem{
NodeID: -1,
Segments: d.queryView.unloadedSealedSegments,
})
}
return
}
@ -213,26 +239,50 @@ func (d *distribution) getTargetVersion() int64 {
// Serviceable returns whether current snapshot is serviceable.
func (d *distribution) Serviceable() bool {
return d.queryView.serviceable.Load()
return d.queryView.Serviceable()
}
// for now, the delegator becomes serviceable only after watchDmChannel is done,
// so we regard all required growing segments as loaded and compute the load ratio from sealed segments only
func (d *distribution) updateServiceable(triggerAction string) {
if d.queryView.version != initialTargetVersion {
serviceable := true
for _, s := range d.queryView.sealedSegments {
if entry, ok := d.sealedSegments[s]; !ok || entry.Offline {
serviceable = false
break
}
}
if serviceable != d.queryView.serviceable.Load() {
d.queryView.serviceable.Store(serviceable)
log.Info("channel distribution serviceable changed",
zap.String("channel", d.channelName),
zap.Bool("serviceable", serviceable),
zap.String("action", triggerAction))
loadedSealedSegments := int64(0)
totalSealedRowCount := int64(0)
unloadedSealedSegments := make([]SegmentEntry, 0)
for id, rowCount := range d.queryView.sealedSegmentRowCount {
if entry, ok := d.sealedSegments[id]; ok && !entry.Offline {
loadedSealedSegments += rowCount
} else {
unloadedSealedSegments = append(unloadedSealedSegments, SegmentEntry{SegmentID: id, NodeID: -1})
}
totalSealedRowCount += rowCount
}
// unloaded segment entry list for partial result
d.queryView.unloadedSealedSegments = unloadedSealedSegments
loadedRatio := 0.0
if len(d.queryView.sealedSegmentRowCount) == 0 {
loadedRatio = 1.0
} else if loadedSealedSegments == 0 {
loadedRatio = 0.0
} else {
loadedRatio = float64(loadedSealedSegments) / float64(totalSealedRowCount)
}
serviceable := loadedRatio >= 1.0
if serviceable != d.queryView.Serviceable() {
log.Info("channel distribution serviceable changed",
zap.String("channel", d.channelName),
zap.Bool("serviceable", serviceable),
zap.Float64("loadedRatio", loadedRatio),
zap.Int64("loadedSealedRowCount", loadedSealedSegments),
zap.Int64("totalSealedRowCount", totalSealedRowCount),
zap.Int("unloadedSealedSegmentNum", len(unloadedSealedSegments)),
zap.Int("totalSealedSegmentNum", len(d.queryView.sealedSegmentRowCount)),
zap.String("action", triggerAction))
}
d.queryView.loadedRatio.Store(loadedRatio)
}
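For reference, a minimal self-contained sketch of the load-ratio computation performed above, using hypothetical row counts; it mirrors the logic of updateServiceable but is not the production code:

package main

import "fmt"

func main() {
	// Hypothetical query view: three sealed segments with known row counts.
	sealedSegmentRowCount := map[int64]int64{4: 100, 5: 100, 6: 100}
	// Only segment 4 is loaded and online on this delegator.
	loaded := map[int64]bool{4: true}

	var loadedRows, totalRows int64
	for id, rows := range sealedSegmentRowCount {
		if loaded[id] {
			loadedRows += rows
		}
		totalRows += rows
	}

	loadedRatio := 1.0 // an empty view counts as fully loaded
	if totalRows > 0 {
		loadedRatio = float64(loadedRows) / float64(totalRows)
	}

	// Serviceable for full results only at 1.0; a request that only needs
	// a 0.3 data ratio could still be answered with a partial result.
	fmt.Printf("loadedRatio=%.2f serviceable=%v\n", loadedRatio, loadedRatio >= 1.0)
}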
// AddDistributions add multiple segment entries.
@ -257,8 +307,14 @@ func (d *distribution) AddDistributions(entries ...SegmentEntry) {
// retain the target version for an already loaded segment to avoid skipping it when executing search
entry.TargetVersion = oldEntry.TargetVersion
} else {
// waiting for sync target version, to become readable
entry.TargetVersion = unreadableTargetVersion
_, ok := d.queryView.sealedSegmentRowCount[entry.SegmentID]
if ok || d.queryView.growingSegments.Contain(entry.SegmentID) {
// set segment version to query view version, to support partial result
entry.TargetVersion = d.queryView.GetVersion()
} else {
// set segment version to unreadableTargetVersion, if it's not in query view
entry.TargetVersion = unreadableTargetVersion
}
}
d.sealedSegments[entry.SegmentID] = entry
}
@ -306,65 +362,73 @@ func (d *distribution) MarkOfflineSegments(segmentIDs ...int64) {
}
}
// UpdateTargetVersion update readable segment version
func (d *distribution) SyncTargetVersion(newVersion int64, partitions []int64, growingInTarget []int64, sealedInTarget []int64, redundantGrowings []int64) {
// SyncTargetVersion updates the readable channel view:
// 1. before the distribution becomes serviceable, the updated view allows partial results to be served
// 2. after the new distribution becomes serviceable, the updated view serves full results
// Notice: if we did not need to stay compatible with 2.5.x, we could simply install the new query view to serve queries;
// it would become serviceable automatically, and the extra sync after the distribution is serviceable would be unnecessary
func (d *distribution) SyncTargetVersion(action *querypb.SyncAction, partitions []int64) {
d.mut.Lock()
defer d.mut.Unlock()
for _, segmentID := range growingInTarget {
entry, ok := d.growingSegments[segmentID]
if !ok {
log.Warn("readable growing segment lost, consume from dml seems too slow",
zap.Int64("segmentID", segmentID))
continue
}
entry.TargetVersion = newVersion
d.growingSegments[segmentID] = entry
}
for _, segmentID := range redundantGrowings {
entry, ok := d.growingSegments[segmentID]
if !ok {
continue
}
entry.TargetVersion = redundantTargetVersion
d.growingSegments[segmentID] = entry
}
for _, segmentID := range sealedInTarget {
entry, ok := d.sealedSegments[segmentID]
if !ok {
continue
}
entry.TargetVersion = newVersion
d.sealedSegments[segmentID] = entry
}
oldValue := d.queryView.version
d.queryView = &channelQueryView{
sealedSegments: sealedInTarget,
partitions: typeutil.NewUniqueSet(partitions...),
version: newVersion,
serviceable: d.queryView.serviceable,
growingSegments: typeutil.NewUniqueSet(action.GetGrowingInTarget()...),
sealedSegmentRowCount: action.GetSealedSegmentRowCount(),
partitions: typeutil.NewUniqueSet(partitions...),
version: action.GetTargetVersion(),
loadedRatio: atomic.NewFloat64(0),
}
// update working partition list
d.genSnapshot()
sealedSet := typeutil.NewUniqueSet(action.GetSealedInTarget()...)
droppedSet := typeutil.NewUniqueSet(action.GetDroppedInTarget()...)
redundantGrowings := make([]int64, 0)
for _, s := range d.growingSegments {
// if the sealed segment already exists or has been dropped, mark the growing segment as redundant
if sealedSet.Contain(s.SegmentID) || droppedSet.Contain(s.SegmentID) {
s.TargetVersion = redundantTargetVersion
d.growingSegments[s.SegmentID] = s
redundantGrowings = append(redundantGrowings, s.SegmentID)
}
}
d.queryView.growingSegments.Range(func(s UniqueID) bool {
entry, ok := d.growingSegments[s]
if !ok {
log.Warn("readable growing segment lost, consume from dml seems too slow",
zap.Int64("segmentID", s))
return true
}
entry.TargetVersion = action.GetTargetVersion()
d.growingSegments[s] = entry
return true
})
for id := range d.queryView.sealedSegmentRowCount {
entry, ok := d.sealedSegments[id]
if !ok {
continue
}
entry.TargetVersion = action.GetTargetVersion()
d.sealedSegments[id] = entry
}
d.genSnapshot()
if d.idfOracle != nil {
d.idfOracle.SetNext(d.current.Load())
d.idfOracle.LazyRemoveGrowings(newVersion, redundantGrowings...)
d.idfOracle.LazyRemoveGrowings(action.GetTargetVersion(), redundantGrowings...)
}
// if the leader view holds fewer sealed segments than the target, the delegator becomes unserviceable
d.updateServiceable("SyncTargetVersion")
log.Info("Update channel query view",
zap.String("channel", d.channelName),
zap.Int64s("partitions", partitions),
zap.Int64("oldVersion", oldValue),
zap.Int64("newVersion", newVersion),
zap.Int("growingSegmentNum", len(growingInTarget)),
zap.Int("sealedSegmentNum", len(sealedInTarget)),
zap.Int64("newVersion", action.GetTargetVersion()),
zap.Bool("serviceable", d.queryView.Serviceable()),
zap.Float64("loadedRatio", d.queryView.GetLoadedRatio()),
zap.Int("growingSegmentNum", len(action.GetGrowingInTarget())),
zap.Int("sealedSegmentNum", len(action.GetSealedInTarget())),
)
}

View File

@ -21,7 +21,10 @@ import (
"time"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
)
type DistributionSuite struct {
@ -30,7 +33,7 @@ type DistributionSuite struct {
}
func (s *DistributionSuite) SetupTest() {
s.dist = NewDistribution("channel-1")
s.dist = NewDistribution("channel-1", NewChannelQueryView(nil, nil, nil, initialTargetVersion))
}
func (s *DistributionSuite) TearDownTest() {
@ -44,6 +47,7 @@ func (s *DistributionSuite) TestAddDistribution() {
growing []SegmentEntry
expected []SnapshotItem
expectedSignalClosed bool
expectedLoadRatio float64
}
cases := []testCase{
@ -177,8 +181,10 @@ func (s *DistributionSuite) TestAddDistribution() {
s.SetupTest()
defer s.TearDownTest()
s.dist.AddGrowing(tc.growing...)
s.dist.SyncTargetVersion(1000, nil, nil, nil, nil)
_, _, version, err := s.dist.PinReadableSegments()
s.dist.SyncTargetVersion(&querypb.SyncAction{
TargetVersion: 1000,
}, nil)
_, _, version, err := s.dist.PinReadableSegments(1.0)
s.Require().NoError(err)
s.dist.AddDistributions(tc.input...)
sealed, _ := s.dist.PeekSegments(false)
@ -225,11 +231,6 @@ func (s *DistributionSuite) TestAddGrowing() {
}
cases := []testCase{
{
tag: "nil input",
input: nil,
expected: []SegmentEntry{},
},
{
tag: "normal_case",
input: []SegmentEntry{
@ -261,8 +262,11 @@ func (s *DistributionSuite) TestAddGrowing() {
defer s.TearDownTest()
s.dist.AddGrowing(tc.input...)
s.dist.SyncTargetVersion(1000, tc.workingParts, []int64{1, 2}, nil, nil)
_, growing, version, err := s.dist.PinReadableSegments()
s.dist.SyncTargetVersion(&querypb.SyncAction{
TargetVersion: 1000,
GrowingInTarget: []int64{1, 2},
}, tc.workingParts)
_, growing, version, err := s.dist.PinReadableSegments(1.0)
s.Require().NoError(err)
defer s.dist.Unpin(version)
@ -452,15 +456,19 @@ func (s *DistributionSuite) TestRemoveDistribution() {
growingIDs := lo.Map(tc.presetGrowing, func(item SegmentEntry, idx int) int64 {
return item.SegmentID
})
sealedIDs := lo.Map(tc.presetSealed, func(item SegmentEntry, idx int) int64 {
return item.SegmentID
sealedSegmentRowCount := lo.SliceToMap(tc.presetSealed, func(item SegmentEntry) (int64, int64) {
return item.SegmentID, 100
})
s.dist.SyncTargetVersion(time.Now().Unix(), nil, growingIDs, sealedIDs, nil)
s.dist.SyncTargetVersion(&querypb.SyncAction{
TargetVersion: 1000,
GrowingInTarget: growingIDs,
SealedSegmentRowCount: sealedSegmentRowCount,
}, nil)
var version int64
if tc.withMockRead {
var err error
_, _, version, err = s.dist.PinReadableSegments()
_, _, version, err = s.dist.PinReadableSegments(1.0)
s.Require().NoError(err)
}
@ -675,10 +683,14 @@ func (s *DistributionSuite) TestMarkOfflineSegments() {
defer s.TearDownTest()
s.dist.AddDistributions(tc.input...)
sealedSegmentID := lo.Map(tc.input, func(t SegmentEntry, _ int) int64 {
return t.SegmentID
sealedSegmentRowCount := lo.SliceToMap(tc.input, func(t SegmentEntry) (int64, int64) {
return t.SegmentID, 100
})
s.dist.SyncTargetVersion(1000, nil, nil, sealedSegmentID, nil)
s.dist.SyncTargetVersion(&querypb.SyncAction{
TargetVersion: 1000,
SealedSegmentRowCount: sealedSegmentRowCount,
DroppedInTarget: nil,
}, nil)
s.dist.MarkOfflineSegments(tc.offlines...)
s.Equal(tc.serviceable, s.dist.Serviceable())
@ -740,29 +752,187 @@ func (s *DistributionSuite) Test_SyncTargetVersion() {
s.dist.AddGrowing(growing...)
s.dist.AddDistributions(sealed...)
s.dist.SyncTargetVersion(2, []int64{1}, []int64{2, 3}, []int64{6}, []int64{})
s.dist.SyncTargetVersion(&querypb.SyncAction{
TargetVersion: 2,
GrowingInTarget: []int64{1},
SealedSegmentRowCount: map[int64]int64{4: 100, 5: 200},
DroppedInTarget: []int64{6},
}, []int64{1})
s1, s2, _, err := s.dist.PinReadableSegments()
s1, s2, _, err := s.dist.PinReadableSegments(1.0)
s.Require().NoError(err)
s.Len(s1[0].Segments, 1)
s.Len(s2, 2)
s.Len(s1[0].Segments, 2)
s.Len(s2, 1)
s1, s2, _ = s.dist.PinOnlineSegments()
s.Len(s1[0].Segments, 3)
s.Len(s2, 3)
s.dist.queryView.serviceable.Store(true)
s.dist.SyncTargetVersion(2, []int64{1}, []int64{222}, []int64{}, []int64{})
s.True(s.dist.Serviceable())
s.dist.SyncTargetVersion(2, []int64{1}, []int64{}, []int64{333}, []int64{})
s.dist.SyncTargetVersion(&querypb.SyncAction{
TargetVersion: 2,
GrowingInTarget: []int64{1},
SealedSegmentRowCount: map[int64]int64{333: 100},
DroppedInTarget: []int64{},
}, []int64{1})
s.False(s.dist.Serviceable())
s.dist.SyncTargetVersion(2, []int64{1}, []int64{}, []int64{333}, []int64{1, 2, 3})
_, _, _, err = s.dist.PinReadableSegments()
_, _, _, err = s.dist.PinReadableSegments(1.0)
s.Error(err)
}
func TestDistributionSuite(t *testing.T) {
suite.Run(t, new(DistributionSuite))
}
func TestNewChannelQueryView(t *testing.T) {
growings := []int64{1, 2, 3}
sealedWithRowCount := map[int64]int64{4: 100, 5: 200, 6: 300}
partitions := []int64{7, 8, 9}
version := int64(10)
view := NewChannelQueryView(growings, sealedWithRowCount, partitions, version)
assert.NotNil(t, view)
assert.ElementsMatch(t, growings, view.growingSegments.Collect())
assert.ElementsMatch(t, lo.Keys(sealedWithRowCount), lo.Keys(view.sealedSegmentRowCount))
assert.True(t, view.partitions.Contain(7))
assert.True(t, view.partitions.Contain(8))
assert.True(t, view.partitions.Contain(9))
assert.Equal(t, version, view.version)
assert.Equal(t, float64(0), view.loadedRatio.Load())
assert.False(t, view.Serviceable())
}
func TestDistribution_NewDistribution(t *testing.T) {
channelName := "test_channel"
growings := []int64{1, 2, 3}
sealedWithRowCount := map[int64]int64{4: 100, 5: 200, 6: 300}
partitions := []int64{7, 8, 9}
version := int64(10)
view := NewChannelQueryView(growings, sealedWithRowCount, partitions, version)
dist := NewDistribution(channelName, view)
assert.NotNil(t, dist)
assert.Equal(t, channelName, dist.channelName)
assert.Equal(t, view, dist.queryView)
assert.NotNil(t, dist.growingSegments)
assert.NotNil(t, dist.sealedSegments)
assert.NotNil(t, dist.snapshots)
assert.NotNil(t, dist.current)
}
func TestDistribution_UpdateServiceable(t *testing.T) {
channelName := "test_channel"
growings := []int64{1, 2, 3}
sealedWithRowCount := map[int64]int64{4: 100, 5: 100, 6: 100}
partitions := []int64{7, 8, 9}
version := int64(10)
view := NewChannelQueryView(growings, sealedWithRowCount, partitions, version)
dist := NewDistribution(channelName, view)
// Test with no segments loaded
dist.updateServiceable("test")
assert.False(t, dist.Serviceable())
assert.Equal(t, float64(0), dist.queryView.GetLoadedRatio())
// Test with some segments loaded
dist.sealedSegments[4] = SegmentEntry{
SegmentID: 4,
Offline: false,
}
dist.growingSegments[1] = SegmentEntry{
SegmentID: 1,
}
dist.updateServiceable("test")
assert.False(t, dist.Serviceable())
assert.Equal(t, float64(2)/float64(6), dist.queryView.GetLoadedRatio())
// Test with all segments loaded
for id := range sealedWithRowCount {
dist.sealedSegments[id] = SegmentEntry{
SegmentID: id,
Offline: false,
}
}
for _, id := range growings {
dist.growingSegments[id] = SegmentEntry{
SegmentID: id,
}
}
dist.updateServiceable("test")
assert.True(t, dist.Serviceable())
assert.Equal(t, float64(1), dist.queryView.GetLoadedRatio())
}
func TestDistribution_SyncTargetVersion(t *testing.T) {
channelName := "test_channel"
growings := []int64{1, 2, 3}
sealedWithRowCount := map[int64]int64{4: 100, 5: 100, 6: 100}
partitions := []int64{7, 8, 9}
version := int64(10)
view := NewChannelQueryView(growings, sealedWithRowCount, partitions, version)
dist := NewDistribution(channelName, view)
// Add some initial segments
dist.growingSegments[1] = SegmentEntry{
SegmentID: 1,
}
dist.sealedSegments[4] = SegmentEntry{
SegmentID: 4,
}
// Create a new sync action
action := &querypb.SyncAction{
GrowingInTarget: []int64{1, 2},
SealedSegmentRowCount: map[int64]int64{4: 100, 5: 100},
DroppedInTarget: []int64{3},
TargetVersion: version + 1,
}
// Sync the view
dist.SyncTargetVersion(action, partitions)
// Verify the changes
assert.Equal(t, action.GetTargetVersion(), dist.queryView.version)
assert.ElementsMatch(t, action.GetGrowingInTarget(), dist.queryView.growingSegments.Collect())
assert.ElementsMatch(t, lo.Keys(action.GetSealedSegmentRowCount()), lo.Keys(dist.queryView.sealedSegmentRowCount))
assert.True(t, dist.queryView.partitions.Contain(7))
assert.True(t, dist.queryView.partitions.Contain(8))
assert.True(t, dist.queryView.partitions.Contain(9))
// Verify growing segment target version
assert.Equal(t, action.GetTargetVersion(), dist.growingSegments[1].TargetVersion)
// Verify sealed segment target version
assert.Equal(t, action.GetTargetVersion(), dist.sealedSegments[4].TargetVersion)
}
func TestDistribution_MarkOfflineSegments(t *testing.T) {
channelName := "test_channel"
view := NewChannelQueryView([]int64{}, map[int64]int64{1: 100, 2: 200}, []int64{}, 0)
dist := NewDistribution(channelName, view)
// Add some segments
dist.sealedSegments[1] = SegmentEntry{
SegmentID: 1,
NodeID: 100,
Version: 1,
}
dist.sealedSegments[2] = SegmentEntry{
SegmentID: 2,
NodeID: 100,
Version: 1,
}
// Mark segments offline
dist.MarkOfflineSegments(1, 2)
// Verify the changes
assert.True(t, dist.sealedSegments[1].Offline)
assert.True(t, dist.sealedSegments[2].Offline)
assert.Equal(t, int64(-1), dist.sealedSegments[1].NodeID)
assert.Equal(t, int64(-1), dist.sealedSegments[2].NodeID)
assert.Equal(t, unreadableTargetVersion, dist.sealedSegments[1].Version)
assert.Equal(t, unreadableTargetVersion, dist.sealedSegments[2].Version)
}

View File

@ -8,8 +8,6 @@ import (
internalpb "github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
mock "github.com/stretchr/testify/mock"
msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
querypb "github.com/milvus-io/milvus/pkg/v2/proto/querypb"
schemapb "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
@ -243,12 +241,12 @@ func (_c *MockShardDelegator_GetPartitionStatsVersions_Call) RunAndReturn(run fu
return _c
}
// GetQueryView provides a mock function with no fields
func (_m *MockShardDelegator) GetQueryView() *channelQueryView {
// GetChannelQueryView provides a mock function with no fields
func (_m *MockShardDelegator) GetChannelQueryView() *channelQueryView {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for GetQueryView")
panic("no return value specified for GetChannelQueryView")
}
var r0 *channelQueryView
@ -263,29 +261,29 @@ func (_m *MockShardDelegator) GetQueryView() *channelQueryView {
return r0
}
// MockShardDelegator_GetQueryView_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetQueryView'
type MockShardDelegator_GetQueryView_Call struct {
// MockShardDelegator_GetChannelQueryView_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetChannelQueryView'
type MockShardDelegator_GetChannelQueryView_Call struct {
*mock.Call
}
// GetQueryView is a helper method to define mock.On call
func (_e *MockShardDelegator_Expecter) GetQueryView() *MockShardDelegator_GetQueryView_Call {
return &MockShardDelegator_GetQueryView_Call{Call: _e.mock.On("GetQueryView")}
// GetChannelQueryView is a helper method to define mock.On call
func (_e *MockShardDelegator_Expecter) GetChannelQueryView() *MockShardDelegator_GetChannelQueryView_Call {
return &MockShardDelegator_GetChannelQueryView_Call{Call: _e.mock.On("GetChannelQueryView")}
}
func (_c *MockShardDelegator_GetQueryView_Call) Run(run func()) *MockShardDelegator_GetQueryView_Call {
func (_c *MockShardDelegator_GetChannelQueryView_Call) Run(run func()) *MockShardDelegator_GetChannelQueryView_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockShardDelegator_GetQueryView_Call) Return(_a0 *channelQueryView) *MockShardDelegator_GetQueryView_Call {
func (_c *MockShardDelegator_GetChannelQueryView_Call) Return(_a0 *channelQueryView) *MockShardDelegator_GetChannelQueryView_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockShardDelegator_GetQueryView_Call) RunAndReturn(run func() *channelQueryView) *MockShardDelegator_GetQueryView_Call {
func (_c *MockShardDelegator_GetChannelQueryView_Call) RunAndReturn(run func() *channelQueryView) *MockShardDelegator_GetChannelQueryView_Call {
_c.Call.Return(run)
return _c
}
@ -1037,9 +1035,9 @@ func (_c *MockShardDelegator_SyncPartitionStats_Call) RunAndReturn(run func(cont
return _c
}
// SyncTargetVersion provides a mock function with given fields: newVersion, partitions, growingInTarget, sealedInTarget, droppedInTarget, checkpoint, deleteSeekPos
func (_m *MockShardDelegator) SyncTargetVersion(newVersion int64, partitions []int64, growingInTarget []int64, sealedInTarget []int64, droppedInTarget []int64, checkpoint *msgpb.MsgPosition, deleteSeekPos *msgpb.MsgPosition) {
_m.Called(newVersion, partitions, growingInTarget, sealedInTarget, droppedInTarget, checkpoint, deleteSeekPos)
// SyncTargetVersion provides a mock function with given fields: action, partitions
func (_m *MockShardDelegator) SyncTargetVersion(action *querypb.SyncAction, partitions []int64) {
_m.Called(action, partitions)
}
// MockShardDelegator_SyncTargetVersion_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SyncTargetVersion'
@ -1048,20 +1046,15 @@ type MockShardDelegator_SyncTargetVersion_Call struct {
}
// SyncTargetVersion is a helper method to define mock.On call
// - newVersion int64
// - action *querypb.SyncAction
// - partitions []int64
// - growingInTarget []int64
// - sealedInTarget []int64
// - droppedInTarget []int64
// - checkpoint *msgpb.MsgPosition
// - deleteSeekPos *msgpb.MsgPosition
func (_e *MockShardDelegator_Expecter) SyncTargetVersion(newVersion interface{}, partitions interface{}, growingInTarget interface{}, sealedInTarget interface{}, droppedInTarget interface{}, checkpoint interface{}, deleteSeekPos interface{}) *MockShardDelegator_SyncTargetVersion_Call {
return &MockShardDelegator_SyncTargetVersion_Call{Call: _e.mock.On("SyncTargetVersion", newVersion, partitions, growingInTarget, sealedInTarget, droppedInTarget, checkpoint, deleteSeekPos)}
func (_e *MockShardDelegator_Expecter) SyncTargetVersion(action interface{}, partitions interface{}) *MockShardDelegator_SyncTargetVersion_Call {
return &MockShardDelegator_SyncTargetVersion_Call{Call: _e.mock.On("SyncTargetVersion", action, partitions)}
}
func (_c *MockShardDelegator_SyncTargetVersion_Call) Run(run func(newVersion int64, partitions []int64, growingInTarget []int64, sealedInTarget []int64, droppedInTarget []int64, checkpoint *msgpb.MsgPosition, deleteSeekPos *msgpb.MsgPosition)) *MockShardDelegator_SyncTargetVersion_Call {
func (_c *MockShardDelegator_SyncTargetVersion_Call) Run(run func(action *querypb.SyncAction, partitions []int64)) *MockShardDelegator_SyncTargetVersion_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].([]int64), args[2].([]int64), args[3].([]int64), args[4].([]int64), args[5].(*msgpb.MsgPosition), args[6].(*msgpb.MsgPosition))
run(args[0].(*querypb.SyncAction), args[1].([]int64))
})
return _c
}
@ -1071,7 +1064,7 @@ func (_c *MockShardDelegator_SyncTargetVersion_Call) Return() *MockShardDelegato
return _c
}
func (_c *MockShardDelegator_SyncTargetVersion_Call) RunAndReturn(run func(int64, []int64, []int64, []int64, []int64, *msgpb.MsgPosition, *msgpb.MsgPosition)) *MockShardDelegator_SyncTargetVersion_Call {
func (_c *MockShardDelegator_SyncTargetVersion_Call) RunAndReturn(run func(*querypb.SyncAction, []int64)) *MockShardDelegator_SyncTargetVersion_Call {
_c.Run(run)
return _c
}

View File

@ -231,6 +231,7 @@ func (node *QueryNode) queryChannel(ctx context.Context, req *querypb.QueryReque
if err != nil {
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.Leader, fmt.Sprint(req.GetReq().GetCollectionID())).Inc()
}
metrics.QueryNodePartialResultCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, fmt.Sprint(req.GetReq().GetCollectionID())).Inc()
}()
log.Debug("start do query with channel",

View File

@ -251,6 +251,13 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
}
}()
queryView := delegator.NewChannelQueryView(
channel.GetUnflushedSegmentIds(),
req.GetSealedSegmentRowCount(),
req.GetPartitionIDs(),
req.GetTargetVersion(),
)
delegator, err := delegator.NewShardDelegator(
ctx,
req.GetCollectionID(),
@ -264,6 +271,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
channel.GetSeekPosition().GetTimestamp(),
node.queryHook,
node.chunkManager,
queryView,
)
if err != nil {
log.Warn("failed to create shard delegator", zap.Error(err))
@ -1252,7 +1260,7 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get
numOfGrowingRows += segment.InsertCount()
}
queryView := delegator.GetQueryView()
queryView := delegator.GetChannelQueryView()
leaderViews = append(leaderViews, &querypb.LeaderView{
Collection: delegator.Collection(),
Channel: key,
@ -1355,16 +1363,7 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi
return id, action.GetCheckpoint().Timestamp
})
shardDelegator.AddExcludedSegments(flushedInfo)
deleteCP := action.GetDeleteCP()
if deleteCP == nil {
// for compatible with 2.4, we use checkpoint as deleteCP when deleteCP is nil
deleteCP = action.GetCheckpoint()
log.Info("use checkpoint as deleteCP",
zap.String("channelName", req.GetChannel()),
zap.Time("deleteSeekPos", tsoutil.PhysicalTime(action.GetCheckpoint().GetTimestamp())))
}
shardDelegator.SyncTargetVersion(action.GetTargetVersion(), req.GetLoadMeta().GetPartitionIDs(), action.GetGrowingInTarget(),
action.GetSealedInTarget(), action.GetDroppedInTarget(), action.GetCheckpoint(), deleteCP)
shardDelegator.SyncTargetVersion(action, req.GetLoadMeta().GetPartitionIDs())
case querypb.SyncType_UpdatePartitionStats:
log.Info("sync update partition stats versions")
shardDelegator.SyncPartitionStats(ctx, action.PartitionStatsVersions)

View File

@ -1393,9 +1393,13 @@ func (suite *ServiceSuite) TestSearch_Failed() {
}
syncVersionAction := &querypb.SyncAction{
Type: querypb.SyncType_UpdateVersion,
SealedInTarget: []int64{1, 2, 3},
TargetVersion: time.Now().UnixMilli(),
Type: querypb.SyncType_UpdateVersion,
SealedSegmentRowCount: map[int64]int64{
1: 100,
2: 200,
3: 300,
},
TargetVersion: time.Now().UnixMilli(),
}
syncReq.Actions = []*querypb.SyncAction{syncVersionAction}

View File

@ -814,6 +814,18 @@ var (
cgoNameLabelName,
cgoTypeLabelName,
})
QueryNodePartialResultCount = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: milvusNamespace,
Subsystem: typeutil.QueryNodeRole,
Name: "partial_result_count",
Help: "count of partial result",
}, []string{
nodeIDLabelName,
queryTypeLabelName,
collectionIDLabelName,
})
)
// RegisterQueryNode registers QueryNode metrics
@ -885,6 +897,7 @@ func RegisterQueryNode(registry *prometheus.Registry) {
registry.MustRegister(QueryNodeDeleteBufferSize)
registry.MustRegister(QueryNodeDeleteBufferRowNum)
registry.MustRegister(QueryNodeCGOCallLatency)
registry.MustRegister(QueryNodePartialResultCount)
// Add cgo metrics
RegisterCGOMetrics(registry)
@ -933,6 +946,13 @@ func CleanupQueryNodeCollectionMetrics(nodeID int64, collectionID int64) {
collectionIDLabelName: collectionIDLabel,
})
QueryNodePartialResultCount.
DeletePartialMatch(
prometheus.Labels{
nodeIDLabelName: nodeIDLabel,
collectionIDLabelName: collectionIDLabel,
})
QueryNodeSearchHitSegmentNum.
DeletePartialMatch(
prometheus.Labels{

View File

@ -281,6 +281,7 @@ message GetSegmentInfoResponse {
message GetShardLeadersRequest {
common.MsgBase base = 1;
int64 collectionID = 2;
bool with_unserviceable_shards = 3;
}
message GetShardLeadersResponse {
@ -297,6 +298,7 @@ message ShardLeadersList { // All leaders of all replicas of one shard
string channel_name = 1;
repeated int64 node_ids = 2;
repeated string node_addrs = 3;
repeated bool serviceable = 4;
}
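The new serviceable list is index-aligned with node_ids and node_addrs. A client-side sketch of filtering leaders by that flag is shown below; it uses a local struct that mirrors the message shape rather than the generated querypb types, so the names here are illustrative only:

package main

import "fmt"

// shardLeaders mirrors the ShardLeadersList fields for illustration.
type shardLeaders struct {
	ChannelName string
	NodeIDs     []int64
	NodeAddrs   []string
	Serviceable []bool
}

// serviceableNodeIDs returns the IDs whose serviceable flag is set,
// assuming the three slices are index-aligned as in the proto definition.
func serviceableNodeIDs(l shardLeaders) []int64 {
	ids := make([]int64, 0, len(l.NodeIDs))
	for i, id := range l.NodeIDs {
		if i < len(l.Serviceable) && l.Serviceable[i] {
			ids = append(ids, id)
		}
	}
	return ids
}

func main() {
	leaders := shardLeaders{
		ChannelName: "by-dev-dml-ch-0", // hypothetical channel name
		NodeIDs:     []int64{1, 2},
		NodeAddrs:   []string{"10.0.0.1:21123", "10.0.0.2:21123"},
		Serviceable: []bool{true, false},
	}
	fmt.Println(serviceableNodeIDs(leaders)) // [1]
}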
message SyncNewCreatedPartitionRequest {
@ -334,6 +336,8 @@ message WatchDmChannelsRequest {
int64 offlineNodeID = 11;
int64 version = 12;
repeated index.IndexInfo index_info_list = 13;
int64 target_version = 14;
map<int64, int64> sealed_segment_row_count = 15; // segmentID -> row count, same as unflushedSegmentIds in vchannelInfo
}
message UnsubDmChannelRequest {
@ -625,7 +629,7 @@ message LeaderView {
}
message LeaderViewStatus {
bool serviceable = 10;
bool serviceable = 1;
}
message SegmentDist {
@ -725,6 +729,7 @@ message SyncAction {
msg.MsgPosition checkpoint = 11;
map<int64, int64> partition_stats_versions = 12;
msg.MsgPosition deleteCP = 13;
map<int64, int64> sealed_segment_row_count = 14; // segmentID -> row count, same as sealedInTarget
}
message SyncDistributionRequest {

File diff suppressed because it is too large

View File

@ -2856,6 +2856,8 @@ type queryNodeConfig struct {
IDFEnableDisk ParamItem `refreshable:"true"`
IDFLocalPath ParamItem `refreshable:"true"`
IDFWriteConcurrenct ParamItem `refreshable:"true"`
// partial search
PartialResultRequiredDataRatio ParamItem `refreshable:"true"`
}
func (p *queryNodeConfig) init(base *BaseTable) {
@ -3786,6 +3788,15 @@ user-task-polling:
Export: true,
}
p.WorkerPoolingSize.Init(base.mgr)
p.PartialResultRequiredDataRatio = ParamItem{
Key: "proxy.partialResultRequiredDataRatio",
Version: "2.6.0",
DefaultValue: "1",
Doc: `partial result required data ratio; defaults to 1, which disables partial results; otherwise it is used as the minimum loaded data ratio required to return a partial result`,
Export: true,
}
p.PartialResultRequiredDataRatio.Init(base.mgr)
}
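As a worked illustration of the new knob's semantics (the numbers are assumptions, not taken from a running cluster): with the default value of 1 a shard must be fully loaded before it answers, while 0.5 lets a delegator answer as long as at least half of the sealed rows are loaded.

package main

import "fmt"

// canServe reports whether a shard whose data is loadedRatio loaded may
// answer under a given partialResultRequiredDataRatio.
func canServe(loadedRatio, requiredRatio float64) bool {
	return loadedRatio >= requiredRatio
}

func main() {
	loadedRatio := 0.6                      // hypothetical: 60% of sealed rows loaded
	fmt.Println(canServe(loadedRatio, 1.0)) // false: the default disables partial results
	fmt.Println(canServe(loadedRatio, 0.5)) // true: a partial result is allowed
}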
// /////////////////////////////////////////////////////////////////////////////

View File

@ -489,6 +489,10 @@ func TestComponentParam(t *testing.T) {
assert.Equal(t, "/var/lib/milvus/data/mmap", Params.MmapDirPath.GetValue())
assert.Equal(t, 60*time.Second, Params.DiskSizeFetchInterval.GetAsDuration(time.Second))
assert.Equal(t, 1.0, Params.PartialResultRequiredDataRatio.GetAsFloat())
params.Save(Params.PartialResultRequiredDataRatio.Key, "0.8")
assert.Equal(t, 0.8, Params.PartialResultRequiredDataRatio.GetAsFloat())
})
t.Run("test dataCoordConfig", func(t *testing.T) {

View File

@ -13,6 +13,8 @@ package retry
import (
"context"
"runtime"
"strconv"
"time"
"github.com/cockroachdb/errors"
@ -23,6 +25,14 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
func getCaller(skip int) string {
_, file, line, ok := runtime.Caller(skip)
if !ok {
return "unknown"
}
return file + ":" + strconv.Itoa(line)
}
// Do will run function with retry mechanism.
// fn is the func to run.
// Option can control the retry times and timeout.
@ -43,7 +53,10 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error {
for i := uint(0); c.attempts == 0 || i < c.attempts; i++ {
if err := fn(); err != nil {
if i%4 == 0 {
log.Warn("retry func failed", zap.Uint("retried", i), zap.Error(err))
log.Warn("retry func failed",
zap.Uint("retried", i),
zap.Error(err),
zap.String("caller", getCaller(2)))
}
if !IsRecoverable(err) {
@ -52,6 +65,7 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error {
zap.Uint("retried", i),
zap.Uint("attempt", c.attempts),
zap.Bool("isContextErr", isContextErr),
zap.String("caller", getCaller(2)),
)
if isContextErr && lastErr != nil {
return lastErr
@ -62,6 +76,7 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error {
log.Warn("retry func failed, not be retryable",
zap.Uint("retried", i),
zap.Uint("attempt", c.attempts),
zap.String("caller", getCaller(2)),
)
return err
}
@ -73,6 +88,7 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error {
zap.Uint("retried", i),
zap.Uint("attempt", c.attempts),
zap.Bool("isContextErr", isContextErr),
zap.String("caller", getCaller(2)),
)
if isContextErr && lastErr != nil {
return lastErr
@ -88,6 +104,7 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error {
log.Warn("retry func failed, ctx done",
zap.Uint("retried", i),
zap.Uint("attempt", c.attempts),
zap.String("caller", getCaller(2)),
)
return lastErr
}
@ -127,11 +144,22 @@ func Handle(ctx context.Context, fn func() (bool, error), opts ...Option) error
for i := uint(0); i < c.attempts; i++ {
if shouldRetry, err := fn(); err != nil {
if i%4 == 0 {
log.Warn("retry func failed", zap.Uint("retried", i), zap.Error(err))
log.Warn("retry func failed",
zap.Uint("retried", i),
zap.String("caller", getCaller(2)),
zap.Error(err),
)
}
if !shouldRetry {
if errors.IsAny(err, context.Canceled, context.DeadlineExceeded) && lastErr != nil {
isContextErr := errors.IsAny(err, context.Canceled, context.DeadlineExceeded)
log.Warn("retry func failed, not be recoverable",
zap.Uint("retried", i),
zap.Uint("attempt", c.attempts),
zap.Bool("isContextErr", isContextErr),
zap.String("caller", getCaller(2)),
)
if isContextErr && lastErr != nil {
return lastErr
}
return err
@ -139,8 +167,14 @@ func Handle(ctx context.Context, fn func() (bool, error), opts ...Option) error
deadline, ok := ctx.Deadline()
if ok && time.Until(deadline) < c.sleep {
// to avoid sleep until ctx done
if errors.IsAny(err, context.Canceled, context.DeadlineExceeded) && lastErr != nil {
isContextErr := errors.IsAny(err, context.Canceled, context.DeadlineExceeded)
log.Warn("retry func failed, deadline",
zap.Uint("retried", i),
zap.Uint("attempt", c.attempts),
zap.Bool("isContextErr", isContextErr),
zap.String("caller", getCaller(2)),
)
if isContextErr && lastErr != nil {
return lastErr
}
return err
@ -151,6 +185,11 @@ func Handle(ctx context.Context, fn func() (bool, error), opts ...Option) error
select {
case <-time.After(c.sleep):
case <-ctx.Done():
log.Warn("retry func failed, ctx done",
zap.Uint("retried", i),
zap.Uint("attempt", c.attempts),
zap.String("caller", getCaller(2)),
)
return lastErr
}
@ -162,6 +201,12 @@ func Handle(ctx context.Context, fn func() (bool, error), opts ...Option) error
return nil
}
}
if lastErr != nil {
log.Warn("retry func failed, reach max retry",
zap.Uint("attempt", c.attempts),
zap.String("caller", getCaller(2)),
)
}
return lastErr
}

View File

@ -349,6 +349,7 @@ func (s *HybridSearchSuite) TestHybridSearchSingleSubReq() {
})
s.NoError(err)
s.NoError(merr.Error(loadStatus))
s.WaitForLoad(ctx, collectionName)
// search
expr := fmt.Sprintf("%s > 0", integration.Int64Field)

View File

@ -0,0 +1,542 @@
package partialsearch
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/suite"
"go.uber.org/atomic"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/tests/integration"
)
const (
dim = 128
dbName = ""
timeout = 10 * time.Second
)
type PartialSearchTestSuit struct {
integration.MiniClusterSuite
}
func (s *PartialSearchTestSuit) SetupSuite() {
paramtable.Init()
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.GracefulStopTimeout.Key, "1")
paramtable.Get().Save(paramtable.Get().QueryCoordCfg.AutoBalanceInterval.Key, "10000")
paramtable.Get().Save(paramtable.Get().QueryCoordCfg.BalanceCheckInterval.Key, "10000")
paramtable.Get().Save(paramtable.Get().QueryCoordCfg.TaskExecutionCap.Key, "1")
// make queries survive while the delegator is down
paramtable.Get().Save(paramtable.Get().ProxyCfg.RetryTimesOnReplica.Key, "10")
s.Require().NoError(s.SetupEmbedEtcd())
}
func (s *PartialSearchTestSuit) initCollection(collectionName string, replica int, channelNum int, segmentNum int, segmentRowNum int) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.CreateCollectionWithConfiguration(ctx, &integration.CreateCollectionConfig{
DBName: dbName,
Dim: dim,
CollectionName: collectionName,
ChannelNum: channelNum,
SegmentNum: segmentNum,
RowNumPerSegment: segmentRowNum,
})
for i := 1; i < replica; i++ {
s.Cluster.AddQueryNode()
}
// load
loadStatus, err := s.Cluster.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
ReplicaNumber: int32(replica),
})
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
s.True(merr.Ok(loadStatus))
s.WaitForLoad(ctx, collectionName)
log.Info("initCollection Done")
}
func (s *PartialSearchTestSuit) executeQuery(collection string) (int, error) {
ctx := context.Background()
queryResult, err := s.Cluster.Proxy.Query(ctx, &milvuspb.QueryRequest{
DbName: "",
CollectionName: collection,
Expr: "",
OutputFields: []string{"count(*)"},
})
if err := merr.CheckRPCCall(queryResult.GetStatus(), err); err != nil {
log.Info("query failed", zap.Error(err))
return 0, err
}
return int(queryResult.FieldsData[0].GetScalars().GetLongData().Data[0]), nil
}
// expect partial results to be returned and no search failures
func (s *PartialSearchTestSuit) TestSingleNodeDownOnSingleReplica() {
partialResultRequiredDataRatio := 0.3
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
// init cluster with 6 querynode
for i := 1; i < 6; i++ {
s.Cluster.AddQueryNode()
}
// init collection with 1 replica, 2 channels, 6 segments, 2000 rows per segment
// expect each node has 1 channel and 2 segments
name := "test_balance_" + funcutil.GenRandomStr()
channelNum := 2
segmentNumInChannel := 6
segmentRowNum := 2000
s.initCollection(name, 1, channelNum, segmentNumInChannel, segmentRowNum)
totalEntities := segmentNumInChannel * segmentRowNum
stopSearchCh := make(chan struct{})
failCounter := atomic.NewInt64(0)
partialResultCounter := atomic.NewInt64(0)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopSearchCh:
log.Info("stop search")
return
default:
numEntities, err := s.executeQuery(name)
if err != nil {
log.Info("query failed", zap.Error(err))
failCounter.Inc()
} else if numEntities < totalEntities {
log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
partialResultCounter.Inc()
s.True(numEntities >= int((float64(totalEntities) * partialResultRequiredDataRatio)))
}
}
}
}()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.Equal(partialResultCounter.Load(), int64(0)) // must fix this corner case
// stop one query node in the single replica; queries should keep succeeding, possibly with partial results
s.Cluster.QueryNode.Stop()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.True(partialResultCounter.Load() >= 0)
close(stopSearchCh)
wg.Wait()
}
// expect some search failures, but partial results are returned before all data is reloaded
// when all query nodes go down, partial search can shorten the recovery time
func (s *PartialSearchTestSuit) TestAllNodeDownOnSingleReplica() {
partialResultRequiredDataRatio := 0.5
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
// init cluster with 2 querynode
s.Cluster.AddQueryNode()
// init collection with 1 replica, 2 channels, 10 segments, 2000 rows per segment
// expect each node has 1 channel and 2 segments
name := "test_balance_" + funcutil.GenRandomStr()
channelNum := 2
segmentNumInChannel := 10
segmentRowNum := 2000
s.initCollection(name, 1, channelNum, segmentNumInChannel, segmentRowNum)
totalEntities := segmentNumInChannel * segmentRowNum
stopSearchCh := make(chan struct{})
failCounter := atomic.NewInt64(0)
partialResultCounter := atomic.NewInt64(0)
partialResultRecoverTs := atomic.NewInt64(0)
fullResultRecoverTs := atomic.NewInt64(0)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopSearchCh:
log.Info("stop search")
return
default:
numEntities, err := s.executeQuery(name)
if err != nil {
log.Info("query failed", zap.Error(err))
failCounter.Inc()
} else if failCounter.Load() > 0 {
if numEntities < totalEntities {
log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
partialResultCounter.Inc()
partialResultRecoverTs.Store(time.Now().UnixNano())
s.True(numEntities >= int((float64(totalEntities) * partialResultRequiredDataRatio)))
} else {
log.Info("query return full result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
fullResultRecoverTs.Store(time.Now().UnixNano())
}
}
}
}
}()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.Equal(partialResultCounter.Load(), int64(0))
// stop all query nodes in the single replica; search failures are expected
for _, qn := range s.Cluster.GetAllQueryNodes() {
qn.Stop()
}
s.Cluster.AddQueryNode()
time.Sleep(20 * time.Second)
s.True(failCounter.Load() >= 0)
s.True(partialResultCounter.Load() >= 0)
log.Info("partialResultRecoverTs", zap.Int64("partialResultRecoverTs", partialResultRecoverTs.Load()), zap.Int64("fullResultRecoverTs", fullResultRecoverTs.Load()))
s.True(partialResultRecoverTs.Load() < fullResultRecoverTs.Load())
close(stopSearchCh)
wg.Wait()
}
// expect full results and no search failures
// because we don't pick the best replica to serve a query, a partial result may still be returned even when only one replica is partially loaded
func (s *PartialSearchTestSuit) TestSingleNodeDownOnMultiReplica() {
partialResultRequiredDataRatio := 0.5
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
// init cluster with 4 querynode
qn1 := s.Cluster.AddQueryNode()
s.Cluster.AddQueryNode()
// init collection with 2 replicas, 2 channels, 4 segments, 2000 rows per segment
// expect each node has 1 channel and 2 segments
name := "test_balance_" + funcutil.GenRandomStr()
channelNum := 2
segmentNumInChannel := 2
segmentRowNum := 2000
s.initCollection(name, 2, channelNum, segmentNumInChannel, segmentRowNum)
totalEntities := segmentNumInChannel * segmentRowNum
stopSearchCh := make(chan struct{})
failCounter := atomic.NewInt64(0)
partialResultCounter := atomic.NewInt64(0)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopSearchCh:
log.Info("stop search")
return
default:
numEntities, err := s.executeQuery(name)
if err != nil {
log.Info("query failed", zap.Error(err))
failCounter.Inc()
} else if numEntities < totalEntities {
log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
partialResultCounter.Inc()
s.True(numEntities >= int((float64(totalEntities) * partialResultRequiredDataRatio)))
}
}
}
}()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.Equal(partialResultCounter.Load(), int64(0))
// stop one query node in one replica; queries should keep succeeding
qn1.Stop()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.True(partialResultCounter.Load() >= 0)
close(stopSearchCh)
wg.Wait()
}
// expect partial results to be returned and no search failures
func (s *PartialSearchTestSuit) TestEachReplicaHasNodeDownOnMultiReplica() {
partialResultRequiredDataRatio := 0.3
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
// init cluster with 12 querynode
for i := 2; i < 12; i++ {
s.Cluster.AddQueryNode()
}
// init collection with 2 replicas, 2 channels, 18 segments, 2000 rows per segment
// expect each node has 1 channel and 2 segments
name := "test_balance_" + funcutil.GenRandomStr()
channelNum := 2
segmentNumInChannel := 9
segmentRowNum := 2000
s.initCollection(name, 2, channelNum, segmentNumInChannel, segmentRowNum)
totalEntities := segmentNumInChannel * segmentRowNum
ctx := context.Background()
stopSearchCh := make(chan struct{})
failCounter := atomic.NewInt64(0)
partialResultCounter := atomic.NewInt64(0)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopSearchCh:
log.Info("stop search")
return
default:
numEntities, err := s.executeQuery(name)
if err != nil {
log.Info("query failed", zap.Error(err))
failCounter.Inc()
continue
} else if numEntities < totalEntities {
log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
partialResultCounter.Inc()
s.True(numEntities >= int((float64(totalEntities) * partialResultRequiredDataRatio)))
}
}
}
}()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.Equal(partialResultCounter.Load(), int64(0))
replicaResp, err := s.Cluster.Proxy.GetReplicas(ctx, &milvuspb.GetReplicasRequest{
DbName: dbName,
CollectionName: name,
})
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, replicaResp.GetStatus().GetErrorCode())
s.Equal(2, len(replicaResp.GetReplicas()))
// for each replica, choose a querynode to stop
for _, replica := range replicaResp.GetReplicas() {
for _, qn := range s.Cluster.GetAllQueryNodes() {
if funcutil.SliceContain(replica.GetNodeIds(), qn.GetQueryNode().GetNodeID()) {
qn.Stop()
break
}
}
}
time.Sleep(10 * time.Second)
s.True(failCounter.Load() >= 0)
s.True(partialResultCounter.Load() >= 0)
close(stopSearchCh)
wg.Wait()
}
// when the partial result required data ratio is set too high, partial search is not triggered and search failures are expected
// for example, with a ratio of 0.8, each query node crash loses 50% of the data, so partial search will not be triggered
func (s *PartialSearchTestSuit) TestPartialResultRequiredDataRatioTooHigh() {
partialResultRequiredDataRatio := 0.8
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
// init cluster with 2 querynode
qn1 := s.Cluster.AddQueryNode()
// init collection with 1 replica, 2 channels, 4 segments, 2000 rows per segment
// expect each node has 1 channel and 2 segments
name := "test_balance_" + funcutil.GenRandomStr()
channelNum := 2
segmentNumInChannel := 2
segmentRowNum := 2000
s.initCollection(name, 1, channelNum, segmentNumInChannel, segmentRowNum)
totalEntities := segmentNumInChannel * segmentRowNum
stopSearchCh := make(chan struct{})
failCounter := atomic.NewInt64(0)
partialResultCounter := atomic.NewInt64(0)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopSearchCh:
log.Info("stop search")
return
default:
numEntities, err := s.executeQuery(name)
if err != nil {
log.Info("query failed", zap.Error(err))
failCounter.Inc()
} else if numEntities < totalEntities {
log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
partialResultCounter.Inc()
s.True(numEntities >= int((float64(totalEntities) * partialResultRequiredDataRatio)))
}
}
}
}()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.Equal(partialResultCounter.Load(), int64(0))
qn1.Stop()
time.Sleep(10 * time.Second)
s.True(failCounter.Load() >= 0)
s.True(partialResultCounter.Load() == 0)
close(stopSearchCh)
wg.Wait()
}
// set the partial result required data ratio to 0; expect no partial results and no search failures even after all query nodes are down
func (s *PartialSearchTestSuit) TestSearchNeverFails() {
partialResultRequiredDataRatio := 0.0
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
// init cluster with 2 querynode
s.Cluster.AddQueryNode()
// init collection with 1 replica, 2 channels, 4 segments, 2000 rows per segment
// expect each node has 1 channel and 2 segments
name := "test_balance_" + funcutil.GenRandomStr()
channelNum := 2
segmentNumInChannel := 2
segmentRowNum := 2000
s.initCollection(name, 1, channelNum, segmentNumInChannel, segmentRowNum)
totalEntities := segmentNumInChannel * segmentRowNum
stopSearchCh := make(chan struct{})
failCounter := atomic.NewInt64(0)
partialResultCounter := atomic.NewInt64(0)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopSearchCh:
log.Info("stop search")
return
default:
numEntities, err := s.executeQuery(name)
if err != nil {
log.Info("query failed", zap.Error(err))
failCounter.Inc()
} else if numEntities < totalEntities {
log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
partialResultCounter.Inc()
s.True(numEntities >= int((float64(totalEntities) * partialResultRequiredDataRatio)))
}
}
}
}()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.Equal(partialResultCounter.Load(), int64(0))
for _, qn := range s.Cluster.GetAllQueryNodes() {
qn.Stop()
}
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.True(partialResultCounter.Load() >= 0)
close(stopSearchCh)
wg.Wait()
}
func (s *PartialSearchTestSuit) TestSkipWaitTSafe() {
partialResultRequiredDataRatio := 0.5
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
// mock a tsafe delay by enlarging the proxy time tick interval
paramtable.Get().Save(paramtable.Get().ProxyCfg.TimeTickInterval.Key, "30000")
// init cluster with 5 querynode
for i := 1; i < 5; i++ {
s.Cluster.AddQueryNode()
}
// init collection with 1 replica, 1 channel, 4 segments, 2000 rows per segment
// expect each node has 1 channel and 2 segments
name := "test_balance_" + funcutil.GenRandomStr()
channelNum := 1
segmentNumInChannel := 4
segmentRowNum := 2000
s.initCollection(name, 1, channelNum, segmentNumInChannel, segmentRowNum)
totalEntities := segmentNumInChannel * segmentRowNum
stopSearchCh := make(chan struct{})
failCounter := atomic.NewInt64(0)
partialResultCounter := atomic.NewInt64(0)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopSearchCh:
log.Info("stop search")
return
default:
numEntities, err := s.executeQuery(name)
if err != nil {
log.Info("query failed", zap.Error(err))
failCounter.Inc()
} else if numEntities < totalEntities {
log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
partialResultCounter.Inc()
s.True(numEntities >= int((float64(totalEntities) * partialResultRequiredDataRatio)))
}
}
}
}()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.Equal(partialResultCounter.Load(), int64(0))
s.Cluster.QueryNode.Stop()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.True(partialResultCounter.Load() >= 0)
close(stopSearchCh)
wg.Wait()
}
func TestPartialResult(t *testing.T) {
suite.Run(t, new(PartialSearchTestSuit))
}

View File

@ -1,347 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package partialsearch
import (
"context"
"fmt"
"strconv"
"sync"
"testing"
"time"
"github.com/stretchr/testify/suite"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metric"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
"github.com/milvus-io/milvus/tests/integration"
)
type PartialSearchSuite struct {
integration.MiniClusterSuite
dim int
numCollections int
rowsPerCollection int
waitTimeInSec time.Duration
prefix string
}
func (s *PartialSearchSuite) setupParam() {
s.dim = 128
s.numCollections = 1
s.rowsPerCollection = 100
s.waitTimeInSec = time.Second * 10
}
func (s *PartialSearchSuite) loadCollection(collectionName string, dim int, wg *sync.WaitGroup) {
c := s.Cluster
dbName := ""
schema := integration.ConstructSchema(collectionName, dim, true)
marshaledSchema, err := proto.Marshal(schema)
s.NoError(err)
createCollectionStatus, err := c.Proxy.CreateCollection(context.TODO(), &milvuspb.CreateCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: common.DefaultShardsNum,
})
s.NoError(err)
err = merr.Error(createCollectionStatus)
s.NoError(err)
showCollectionsResp, err := c.Proxy.ShowCollections(context.TODO(), &milvuspb.ShowCollectionsRequest{})
s.NoError(err)
s.True(merr.Ok(showCollectionsResp.GetStatus()))
batchSize := 500000
for start := 0; start < s.rowsPerCollection; start += batchSize {
rowNum := batchSize
if start+batchSize > s.rowsPerCollection {
rowNum = s.rowsPerCollection - start
}
fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
hashKeys := integration.GenerateHashKeys(rowNum)
insertResult, err := c.Proxy.Insert(context.TODO(), &milvuspb.InsertRequest{
DbName: dbName,
CollectionName: collectionName,
FieldsData: []*schemapb.FieldData{fVecColumn},
HashKeys: hashKeys,
NumRows: uint32(rowNum),
})
s.NoError(err)
s.True(merr.Ok(insertResult.GetStatus()))
}
log.Info("=========================Data insertion finished=========================")
// flush
flushResp, err := c.Proxy.Flush(context.TODO(), &milvuspb.FlushRequest{
DbName: dbName,
CollectionNames: []string{collectionName},
})
s.NoError(err)
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
ids := segmentIDs.GetData()
s.Require().NotEmpty(segmentIDs)
s.Require().True(has)
flushTs, has := flushResp.GetCollFlushTs()[collectionName]
s.True(has)
segments, err := c.MetaWatcher.ShowSegments()
s.NoError(err)
s.NotEmpty(segments)
s.WaitForFlush(context.TODO(), ids, flushTs, dbName, collectionName)
log.Info("=========================Data flush finished=========================")
// create index
createIndexStatus, err := c.Proxy.CreateIndex(context.TODO(), &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: integration.FloatVecField,
IndexName: "_default",
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.IP),
})
s.NoError(err)
err = merr.Error(createIndexStatus)
s.NoError(err)
s.WaitForIndexBuilt(context.TODO(), collectionName, integration.FloatVecField)
log.Info("=========================Index created=========================")
// load
loadStatus, err := c.Proxy.LoadCollection(context.TODO(), &milvuspb.LoadCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
})
s.NoError(err)
err = merr.Error(loadStatus)
s.NoError(err)
s.WaitForLoad(context.TODO(), collectionName)
log.Info("=========================Collection loaded=========================")
wg.Done()
}
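
// checkCollectionLoaded returns true once the collection reports 100% loading progress.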
func (s *PartialSearchSuite) checkCollectionLoaded(collectionName string) bool {
loadProgress, err := s.Cluster.Proxy.GetLoadingProgress(context.TODO(), &milvuspb.GetLoadingProgressRequest{
DbName: "",
CollectionName: collectionName,
})
s.NoError(err)
return loadProgress.GetProgress() == int64(100)
}
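
// checkCollectionsLoaded reports whether every collection in [startCollectionID, endCollectionID) is fully loaded.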
func (s *PartialSearchSuite) checkCollectionsLoaded(startCollectionID, endCollectionID int) bool {
notLoaded := 0
loaded := 0
for idx := startCollectionID; idx < endCollectionID; idx++ {
collectionName := s.prefix + "_" + strconv.Itoa(idx)
if s.checkCollectionLoaded(collectionName) {
loaded++
} else {
notLoaded++
}
}
log.Info(fmt.Sprintf("loading status: %d/%d", loaded, endCollectionID-startCollectionID))
return notLoaded == 0
}
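
// checkAllCollectionsLoaded reports whether all test collections are fully loaded.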
func (s *PartialSearchSuite) checkAllCollectionsLoaded() bool {
return s.checkCollectionsLoaded(0, s.numCollections)
}
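
// search runs a count(*) query and a vector search against the collection and expects both to succeed.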
func (s *PartialSearchSuite) search(collectionName string, dim int) {
c := s.Cluster
var err error
// Query
queryReq := &milvuspb.QueryRequest{
Base: nil,
CollectionName: collectionName,
PartitionNames: nil,
Expr: "",
OutputFields: []string{"count(*)"},
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
}
queryResult, err := c.Proxy.Query(context.TODO(), queryReq)
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, queryResult.GetStatus().GetErrorCode())
s.Equal(1, len(queryResult.GetFieldsData()))
numEntities := queryResult.GetFieldsData()[0].GetScalars().GetLongData().GetData()[0]
s.Equal(int64(s.rowsPerCollection), numEntities)
// Search
expr := fmt.Sprintf("%s > 0", integration.Int64Field)
nq := 10
topk := 10
roundDecimal := -1
radius := 10
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.IP)
params["radius"] = radius
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, dim, topk, roundDecimal)
searchResult, err := c.Proxy.Search(context.TODO(), searchReq)
s.NoError(err)
err = merr.Error(searchResult.GetStatus())
s.NoError(err)
}
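
// FailOnSearch runs a vector search against the collection and expects the returned status to carry an error.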
func (s *PartialSearchSuite) FailOnSearch(collectionName string) {
c := s.Cluster
expr := fmt.Sprintf("%s > 0", integration.Int64Field)
nq := 10
topk := 10
roundDecimal := -1
radius := 10
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.IP)
params["radius"] = radius
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, s.dim, topk, roundDecimal)
searchResult, err := c.Proxy.Search(context.TODO(), searchReq)
s.NoError(err)
err = merr.Error(searchResult.GetStatus())
s.Require().Error(err)
}
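
// setupData creates the test collections concurrently, then verifies that they are loaded and searchable.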
func (s *PartialSearchSuite) setupData() {
// Create the test collections and load them with data.
log.Info("=========================Start to inject data=========================")
s.prefix = "TestPartialSearchUtil" + funcutil.GenRandomStr()
searchName := s.prefix + "_0"
wg := sync.WaitGroup{}
for idx := 0; idx < s.numCollections; idx++ {
wg.Add(1)
go s.loadCollection(s.prefix+"_"+strconv.Itoa(idx), s.dim, &wg)
}
wg.Wait()
log.Info("=========================Data injection finished=========================")
s.True(s.checkAllCollectionsLoaded())
log.Info(fmt.Sprintf("=========================start to search %s=========================", searchName))
s.search(searchName, s.dim)
log.Info("=========================Search finished=========================")
time.Sleep(s.waitTimeInSec)
s.True(s.checkAllCollectionsLoaded())
log.Info(fmt.Sprintf("=========================start to search2 %s=========================", searchName))
s.search(searchName, s.dim)
log.Info("=========================Search2 finished=========================")
s.checkAllCollectionsReady()
}
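
// checkCollectionsReady verifies that search and query succeed for every collection in the given range.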
func (s *PartialSearchSuite) checkCollectionsReady(startCollectionID, endCollectionID int) {
for i := startCollectionID; i < endCollectionID; i++ {
collectionName := s.prefix + "_" + strconv.Itoa(i)
s.search(collectionName, s.dim)
queryReq := &milvuspb.QueryRequest{
CollectionName: collectionName,
Expr: "",
OutputFields: []string{"count(*)"},
}
_, err := s.Cluster.Proxy.Query(context.TODO(), queryReq)
s.NoError(err)
}
}
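
// checkAllCollectionsReady verifies that every test collection serves search and query requests.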
func (s *PartialSearchSuite) checkAllCollectionsReady() {
s.checkCollectionsReady(0, s.numCollections)
}
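
// releaseSegmentsReq builds a ReleaseSegments request that drops one historical segment from the given node and shard.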
func (s *PartialSearchSuite) releaseSegmentsReq(collectionID, nodeID, segmentID typeutil.UniqueID, shard string) *querypb.ReleaseSegmentsRequest {
req := &querypb.ReleaseSegmentsRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_ReleaseSegments),
commonpbutil.WithMsgID(1<<30),
commonpbutil.WithTargetID(nodeID),
),
NodeID: nodeID,
CollectionID: collectionID,
SegmentIDs: []int64{segmentID},
Scope: querypb.DataScope_Historical,
Shard: shard,
NeedTransfer: false,
}
return req
}
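
// describeCollection returns the collection ID and its virtual channel names.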
func (s *PartialSearchSuite) describeCollection(name string) (int64, []string) {
resp, err := s.Cluster.Proxy.DescribeCollection(context.TODO(), &milvuspb.DescribeCollectionRequest{
DbName: "default",
CollectionName: name,
})
s.NoError(err)
log.Info(fmt.Sprintf("describe collection: %v", resp))
return resp.CollectionID, resp.VirtualChannelNames
}
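
// getSegmentIDs returns the IDs of all persisted segments of the collection.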
func (s *PartialSearchSuite) getSegmentIDs(collectionName string) []int64 {
resp, err := s.Cluster.Proxy.GetPersistentSegmentInfo(context.TODO(), &milvuspb.GetPersistentSegmentInfoRequest{
DbName: "default",
CollectionName: collectionName,
})
s.NoError(err)
var res []int64
for _, seg := range resp.Infos {
res = append(res, seg.SegmentID)
}
return res
}
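
// TestPartialSearch releases one segment from the query node and verifies that search on the affected collection fails while the segment is unavailable.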
func (s *PartialSearchSuite) TestPartialSearch() {
s.setupParam()
s.setupData()
startCollectionID := 0
endCollectionID := s.numCollections
// Search should work in the beginning.
s.checkCollectionsReady(startCollectionID, endCollectionID)
// Test case: release one segment from the query node.
// Partial search is not expected to serve this query yet, so the search should fail.
c := s.Cluster
q1 := c.QueryNode
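// Stop the coordinator checkers so that the released segment does not get reloaded automatically while the test runs.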
c.MixCoord.StopCheckerForTestOnly()
collectionName := s.prefix + "_0"
nodeID := q1.GetServerIDForTestOnly()
collectionID, channels := s.describeCollection(collectionName)
segs := s.getSegmentIDs(collectionName)
s.Require().Positive(len(segs))
s.Require().Positive(len(channels))
segmentID := segs[0]
shard := channels[0]
req := s.releaseSegmentsReq(collectionID, nodeID, segmentID, shard)
status, err := q1.ReleaseSegments(context.TODO(), req)
s.NoError(err)
s.True(merr.Ok(status))
s.FailOnSearch(collectionName)
c.MixCoord.StartCheckerForTestOnly()
}
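
// TestPartialSearchUtil runs the PartialSearchSuite.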
func TestPartialSearchUtil(t *testing.T) {
suite.Run(t, new(PartialSearchSuite))
}