milvus/tests/integration/partialsearch/partial_result_on_node_down_test.go
feat: Implement partial result support on node down (#42009)
issue: https://github.com/milvus-io/milvus/issues/41690
This commit implements partial search result functionality when query
nodes go down, improving system availability during node failures. The
changes include:

- Enhanced load balancing in proxy (lb_policy.go) to handle node
failures with retry support
- Added partial search result capability in querynode delegator and
distribution logic
- Implemented tests for various partial result scenarios when nodes go
down
- Added metrics to track partial search results in querynode_metrics.go
- Updated parameter configuration to support partial result required
data ratio
- Replaced old partial_search_test.go with more comprehensive
partial_result_on_node_down_test.go
- Updated proto definitions and improved retry logic

These changes improve query resilience by returning partial results to
users when some query nodes are unavailable, ensuring that queries don't
completely fail when a portion of data remains accessible.

---------

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
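
For reference, a minimal sketch of how the new knob is toggled via paramtable, mirroring the pattern used in the test file below (the GetAsFloat read-back is an assumption for illustration, not part of this commit):

package main

import (
	"fmt"

	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)

func main() {
	paramtable.Init()
	// allow a query to succeed as long as at least 50% of the loaded data is reachable
	key := paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key
	paramtable.Get().Save(key, "0.5")
	defer paramtable.Get().Reset(key)

	fmt.Println("partial result required data ratio:",
		paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.GetAsFloat())
}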

package partialsearch

import (
	"context"
	"fmt"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/suite"
	"go.uber.org/atomic"
	"go.uber.org/zap"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
	"github.com/milvus-io/milvus/pkg/v2/log"
	"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
	"github.com/milvus-io/milvus/pkg/v2/util/merr"
	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
	"github.com/milvus-io/milvus/tests/integration"
)

const (
dim = 128
dbName = ""
timeout = 10 * time.Second
)
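// PartialSearchTestSuit verifies that queries return partial results instead of failing when query nodes go down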
type PartialSearchTestSuit struct {
integration.MiniClusterSuite
}
func (s *PartialSearchTestSuit) SetupSuite() {
paramtable.Init()
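// stop querynodes quickly and slow down auto balance so node-down scenarios stay deterministic within the short test windows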
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.GracefulStopTimeout.Key, "1")
paramtable.Get().Save(paramtable.Get().QueryCoordCfg.AutoBalanceInterval.Key, "10000")
paramtable.Get().Save(paramtable.Get().QueryCoordCfg.BalanceCheckInterval.Key, "10000")
paramtable.Get().Save(paramtable.Get().QueryCoordCfg.TaskExecutionCap.Key, "1")
// make query survive when delegator is down
paramtable.Get().Save(paramtable.Get().ProxyCfg.RetryTimesOnReplica.Key, "10")
s.Require().NoError(s.SetupEmbedEtcd())
}
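// initCollection creates a collection with the given channel/segment layout, adds extra querynodes for additional replicas, and loads it with the requested replica number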
func (s *PartialSearchTestSuit) initCollection(collectionName string, replica int, channelNum int, segmentNum int, segmentRowNum int) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.CreateCollectionWithConfiguration(ctx, &integration.CreateCollectionConfig{
DBName: dbName,
Dim: dim,
CollectionName: collectionName,
ChannelNum: channelNum,
SegmentNum: segmentNum,
RowNumPerSegment: segmentRowNum,
})
for i := 1; i < replica; i++ {
s.Cluster.AddQueryNode()
}
// load
loadStatus, err := s.Cluster.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
ReplicaNumber: int32(replica),
})
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
s.True(merr.Ok(loadStatus))
s.WaitForLoad(ctx, collectionName)
log.Info("initCollection Done")
}
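// executeQuery runs a count(*) query and returns the number of visible entities; a value below the expected total indicates a partial result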
func (s *PartialSearchTestSuit) executeQuery(collection string) (int, error) {
ctx := context.Background()
queryResult, err := s.Cluster.Proxy.Query(ctx, &milvuspb.QueryRequest{
DbName: "",
CollectionName: collection,
Expr: "",
OutputFields: []string{"count(*)"},
})
if err := merr.CheckRPCCall(queryResult.GetStatus(), err); err != nil {
log.Info("query failed", zap.Error(err))
return 0, err
}
return int(queryResult.FieldsData[0].GetScalars().GetLongData().Data[0]), nil
}
// expect partial results to be returned, with no search failures
func (s *PartialSearchTestSuit) TestSingleNodeDownOnSingleReplica() {
partialResultRequiredDataRatio := 0.3
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
// init cluster with 6 querynodes
for i := 1; i < 6; i++ {
s.Cluster.AddQueryNode()
}
// init collection with 1 replica, 2 channels, 6 segments, 2000 rows per segment
// expect channels and segments to be spread across the querynodes, so losing one node loses only part of the data
name := "test_balance_" + funcutil.GenRandomStr()
channelNum := 2
segmentNumInChannel := 6
segmentRowNum := 2000
s.initCollection(name, 1, channelNum, segmentNumInChannel, segmentRowNum)
totalEntities := segmentNumInChannel * segmentRowNum
stopSearchCh := make(chan struct{})
failCounter := atomic.NewInt64(0)
partialResultCounter := atomic.NewInt64(0)
wg := sync.WaitGroup{}
wg.Add(1)
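// background worker: keep issuing count(*) queries, counting failures and partial results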
go func() {
defer wg.Done()
for {
select {
case <-stopSearchCh:
log.Info("stop search")
return
default:
numEntities, err := s.executeQuery(name)
if err != nil {
log.Info("query failed", zap.Error(err))
failCounter.Inc()
} else if numEntities < totalEntities {
log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
partialResultCounter.Inc()
s.True(numEntities >= int((float64(totalEntities) * partialResultRequiredDataRatio)))
}
}
}
}()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.Equal(partialResultCounter.Load(), int64(0)) // must fix this corner case
// stop a querynode in the single replica; expect partial results but no search failures
s.Cluster.QueryNode.Stop()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.True(partialResultCounter.Load() >= 0)
close(stopSearchCh)
wg.Wait()
}
// expect some search failures, but partial results become available before all data is reloaded;
// when every querynode serving a replica goes down, partial search shortens the recovery time
func (s *PartialSearchTestSuit) TestAllNodeDownOnSingleReplica() {
partialResultRequiredDataRatio := 0.5
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
// init cluster with 2 querynodes
s.Cluster.AddQueryNode()
// init collection with 1 replica, 2 channels, 10 segments, 2000 rows per segment
// expect each node to hold 1 channel and half of the segments
name := "test_balance_" + funcutil.GenRandomStr()
channelNum := 2
segmentNumInChannel := 10
segmentRowNum := 2000
s.initCollection(name, 1, channelNum, segmentNumInChannel, segmentRowNum)
totalEntities := segmentNumInChannel * segmentRowNum
stopSearchCh := make(chan struct{})
failCounter := atomic.NewInt64(0)
partialResultCounter := atomic.NewInt64(0)
partialResultRecoverTs := atomic.NewInt64(0)
fullResultRecoverTs := atomic.NewInt64(0)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopSearchCh:
log.Info("stop search")
return
default:
numEntities, err := s.executeQuery(name)
if err != nil {
log.Info("query failed", zap.Error(err))
failCounter.Inc()
} else if failCounter.Load() > 0 {
if numEntities < totalEntities {
log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
partialResultCounter.Inc()
partialResultRecoverTs.Store(time.Now().UnixNano())
s.True(numEntities >= int((float64(totalEntities) * partialResultRequiredDataRatio)))
} else {
log.Info("query return full result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
fullResultRecoverTs.Store(time.Now().UnixNano())
}
}
}
}
}()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.Equal(partialResultCounter.Load(), int64(0))
// stop all querynodes in the single replica; search failures are expected until the data becomes available again
for _, qn := range s.Cluster.GetAllQueryNodes() {
qn.Stop()
}
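// bring up a fresh querynode so the data can be reloaded; partial results should become servable before the full dataset is back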
s.Cluster.AddQueryNode()
time.Sleep(20 * time.Second)
s.True(failCounter.Load() >= 0)
s.True(partialResultCounter.Load() >= 0)
log.Info("partialResultRecoverTs", zap.Int64("partialResultRecoverTs", partialResultRecoverTs.Load()), zap.Int64("fullResultRecoverTs", fullResultRecoverTs.Load()))
s.True(partialResultRecoverTs.Load() < fullResultRecoverTs.Load())
close(stopSearchCh)
wg.Wait()
}
// expect full results and no search failures;
// since the proxy does not pick the best replica for each query, partial results may still be returned even when only one replica is partially loaded
func (s *PartialSearchTestSuit) TestSingleNodeDownOnMultiReplica() {
partialResultRequiredDataRatio := 0.5
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
// init cluster with 4 querynodes
qn1 := s.Cluster.AddQueryNode()
s.Cluster.AddQueryNode()
// init collection with 2 replicas, 2 channels, 4 segments, 2000 rows per segment
// expect each node has 1 channel and 2 segments
name := "test_balance_" + funcutil.GenRandomStr()
channelNum := 2
segmentNumInChannel := 2
segmentRowNum := 2000
s.initCollection(name, 2, channelNum, segmentNumInChannel, segmentRowNum)
totalEntities := segmentNumInChannel * segmentRowNum
stopSearchCh := make(chan struct{})
failCounter := atomic.NewInt64(0)
partialResultCounter := atomic.NewInt64(0)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopSearchCh:
log.Info("stop search")
return
default:
numEntities, err := s.executeQuery(name)
if err != nil {
log.Info("query failed", zap.Error(err))
failCounter.Inc()
} else if numEntities < totalEntities {
log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
partialResultCounter.Inc()
s.True(numEntities >= int((float64(totalEntities) * partialResultRequiredDataRatio)))
}
}
}
}()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.Equal(partialResultCounter.Load(), int64(0))
// stop a querynode in one replica; the other replica can still serve, so expect no search failures
qn1.Stop()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.True(partialResultCounter.Load() >= 0)
close(stopSearchCh)
wg.Wait()
}
// expect partial results to be returned, with no search failures
func (s *PartialSearchTestSuit) TestEachReplicaHasNodeDownOnMultiReplica() {
partialResultRequiredDataRatio := 0.3
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
// init cluster with 12 querynodes
for i := 2; i < 12; i++ {
s.Cluster.AddQueryNode()
}
// init collection with 2 replicas, 2 channels, 18 segments, 2000 rows per segment
// expect channels and segments to be spread across the querynodes of each replica
name := "test_balance_" + funcutil.GenRandomStr()
channelNum := 2
segmentNumInChannel := 9
segmentRowNum := 2000
s.initCollection(name, 2, channelNum, segmentNumInChannel, segmentRowNum)
totalEntities := segmentNumInChannel * segmentRowNum
ctx := context.Background()
stopSearchCh := make(chan struct{})
failCounter := atomic.NewInt64(0)
partialResultCounter := atomic.NewInt64(0)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopSearchCh:
log.Info("stop search")
return
default:
numEntities, err := s.executeQuery(name)
if err != nil {
log.Info("query failed", zap.Error(err))
failCounter.Inc()
continue
} else if numEntities < totalEntities {
log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
partialResultCounter.Inc()
s.True(numEntities >= int((float64(totalEntities) * partialResultRequiredDataRatio)))
}
}
}
}()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.Equal(partialResultCounter.Load(), int64(0))
replicaResp, err := s.Cluster.Proxy.GetReplicas(ctx, &milvuspb.GetReplicasRequest{
DbName: dbName,
CollectionName: name,
})
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, replicaResp.GetStatus().GetErrorCode())
s.Equal(2, len(replicaResp.GetReplicas()))
// for each replica, choose a querynode to stop
for _, replica := range replicaResp.GetReplicas() {
for _, qn := range s.Cluster.GetAllQueryNodes() {
if funcutil.SliceContain(replica.GetNodeIds(), qn.GetQueryNode().GetNodeID()) {
qn.Stop()
break
}
}
}
time.Sleep(10 * time.Second)
s.True(failCounter.Load() >= 0)
s.True(partialResultCounter.Load() >= 0)
close(stopSearchCh)
wg.Wait()
}
// when the partial result required data ratio is set too high, partial search is not triggered and search failures are expected;
// for example, with a ratio of 0.8, losing one of two querynodes loses about 50% of the data, so the requirement cannot be met
func (s *PartialSearchTestSuit) TestPartialResultRequiredDataRatioTooHigh() {
partialResultRequiredDataRatio := 0.8
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
// init cluster with 2 querynodes
qn1 := s.Cluster.AddQueryNode()
// init collection with 1 replica, 2 channels, 4 segments, 2000 rows per segment
// expect each node has 1 channel and 2 segments
name := "test_balance_" + funcutil.GenRandomStr()
channelNum := 2
segmentNumInChannel := 2
segmentRowNum := 2000
s.initCollection(name, 1, channelNum, segmentNumInChannel, segmentRowNum)
totalEntities := segmentNumInChannel * segmentRowNum
stopSearchCh := make(chan struct{})
failCounter := atomic.NewInt64(0)
partialResultCounter := atomic.NewInt64(0)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopSearchCh:
log.Info("stop search")
return
default:
numEntities, err := s.executeQuery(name)
if err != nil {
log.Info("query failed", zap.Error(err))
failCounter.Inc()
} else if numEntities < totalEntities {
log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
partialResultCounter.Inc()
s.True(numEntities >= int((float64(totalEntities) * partialResultRequiredDataRatio)))
}
}
}
}()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.Equal(partialResultCounter.Load(), int64(0))
qn1.Stop()
time.Sleep(10 * time.Second)
s.True(failCounter.Load() >= 0)
s.True(partialResultCounter.Load() == 0)
close(stopSearchCh)
wg.Wait()
}
// with the partial result required data ratio set to 0, searches should never fail, even after all querynodes go down
func (s *PartialSearchTestSuit) TestSearchNeverFails() {
partialResultRequiredDataRatio := 0.0
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
// init cluster with 2 querynodes
s.Cluster.AddQueryNode()
// init collection with 1 replica, 2 channels, 4 segments, 2000 rows per segment
// expect each node has 1 channel and 2 segments
name := "test_balance_" + funcutil.GenRandomStr()
channelNum := 2
segmentNumInChannel := 2
segmentRowNum := 2000
s.initCollection(name, 1, channelNum, segmentNumInChannel, segmentRowNum)
totalEntities := segmentNumInChannel * segmentRowNum
stopSearchCh := make(chan struct{})
failCounter := atomic.NewInt64(0)
partialResultCounter := atomic.NewInt64(0)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopSearchCh:
log.Info("stop search")
return
default:
numEntities, err := s.executeQuery(name)
if err != nil {
log.Info("query failed", zap.Error(err))
failCounter.Inc()
} else if numEntities < totalEntities {
log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
partialResultCounter.Inc()
s.True(numEntities >= int((float64(totalEntities) * partialResultRequiredDataRatio)))
}
}
}
}()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.Equal(partialResultCounter.Load(), int64(0))
for _, qn := range s.Cluster.GetAllQueryNodes() {
qn.Stop()
}
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.True(partialResultCounter.Load() >= 0)
close(stopSearchCh)
wg.Wait()
}
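// with the proxy time tick interval enlarged, waiting for a fresh tsafe would block queries; expect searches to keep succeeding (partial results allowed) after a querynode stops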
func (s *PartialSearchTestSuit) TestSkipWaitTSafe() {
partialResultRequiredDataRatio := 0.5
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
// mock tsafe Delay
paramtable.Get().Save(paramtable.Get().ProxyCfg.TimeTickInterval.Key, "30000")
// init cluster with 5 querynodes
for i := 1; i < 5; i++ {
s.Cluster.AddQueryNode()
}
// init collection with 1 replica, 1 channel, 4 segments, 2000 rows per segment
// expect the single channel and the segments to be spread across the querynodes
name := "test_balance_" + funcutil.GenRandomStr()
channelNum := 1
segmentNumInChannel := 4
segmentRowNum := 2000
s.initCollection(name, 1, channelNum, segmentNumInChannel, segmentRowNum)
totalEntities := segmentNumInChannel * segmentRowNum
stopSearchCh := make(chan struct{})
failCounter := atomic.NewInt64(0)
partialResultCounter := atomic.NewInt64(0)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopSearchCh:
log.Info("stop search")
return
default:
numEntities, err := s.executeQuery(name)
if err != nil {
log.Info("query failed", zap.Error(err))
failCounter.Inc()
} else if numEntities < totalEntities {
log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
partialResultCounter.Inc()
s.True(numEntities >= int((float64(totalEntities) * partialResultRequiredDataRatio)))
}
}
}
}()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.Equal(partialResultCounter.Load(), int64(0))
s.Cluster.QueryNode.Stop()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.True(partialResultCounter.Load() >= 0)
close(stopSearchCh)
wg.Wait()
}
func TestPartialResult(t *testing.T) {
suite.Run(t, new(PartialSearchTestSuit))
}