Fix current target may be updated to an invalid target (#21742)

Signed-off-by: yah01 <yang.cen@zilliz.com>
This commit is contained in:
yah01 2023-01-17 11:41:51 +08:00 committed by GitHub
parent 81326ca0a1
commit c8f89907b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 80 additions and 31 deletions

View File

@ -37,6 +37,7 @@ type CollectionObserver struct {
dist *meta.DistributionManager dist *meta.DistributionManager
meta *meta.Meta meta *meta.Meta
targetMgr *meta.TargetManager targetMgr *meta.TargetManager
targetObserver *TargetObserver
collectionLoadedCount map[int64]int collectionLoadedCount map[int64]int
partitionLoadedCount map[int64]int partitionLoadedCount map[int64]int
@ -47,12 +48,14 @@ func NewCollectionObserver(
dist *meta.DistributionManager, dist *meta.DistributionManager,
meta *meta.Meta, meta *meta.Meta,
targetMgr *meta.TargetManager, targetMgr *meta.TargetManager,
targetObserver *TargetObserver,
) *CollectionObserver { ) *CollectionObserver {
return &CollectionObserver{ return &CollectionObserver{
stopCh: make(chan struct{}), stopCh: make(chan struct{}),
dist: dist, dist: dist,
meta: meta, meta: meta,
targetMgr: targetMgr, targetMgr: targetMgr,
targetObserver: targetObserver,
collectionLoadedCount: make(map[int64]int), collectionLoadedCount: make(map[int64]int),
partitionLoadedCount: make(map[int64]int), partitionLoadedCount: make(map[int64]int),
} }
@ -201,9 +204,8 @@ func (ob *CollectionObserver) observeCollectionLoadStatus(collection *meta.Colle
return return
} }
ob.collectionLoadedCount[collection.GetCollectionID()] = loadedCount ob.collectionLoadedCount[collection.GetCollectionID()] = loadedCount
if updated.LoadPercentage == 100 { if updated.LoadPercentage == 100 && ob.targetObserver.Check(updated.GetCollectionID()) {
delete(ob.collectionLoadedCount, collection.GetCollectionID()) delete(ob.collectionLoadedCount, collection.GetCollectionID())
ob.targetMgr.UpdateCollectionCurrentTarget(updated.CollectionID)
updated.Status = querypb.LoadStatus_Loaded updated.Status = querypb.LoadStatus_Loaded
ob.meta.CollectionManager.UpdateCollection(updated) ob.meta.CollectionManager.UpdateCollection(updated)
@ -265,9 +267,8 @@ func (ob *CollectionObserver) observePartitionLoadStatus(partition *meta.Partiti
return return
} }
ob.partitionLoadedCount[partition.GetPartitionID()] = loadedCount ob.partitionLoadedCount[partition.GetPartitionID()] = loadedCount
if updated.LoadPercentage == 100 { if updated.LoadPercentage == 100 && ob.targetObserver.Check(updated.GetCollectionID()) {
delete(ob.partitionLoadedCount, partition.GetPartitionID()) delete(ob.partitionLoadedCount, partition.GetPartitionID())
ob.targetMgr.UpdateCollectionCurrentTarget(partition.GetCollectionID(), partition.GetPartitionID())
updated.Status = querypb.LoadStatus_Loaded updated.Status = querypb.LoadStatus_Loaded
ob.meta.CollectionManager.PutPartition(updated) ob.meta.CollectionManager.PutPartition(updated)

View File

@ -59,6 +59,7 @@ type CollectionObserverSuite struct {
dist *meta.DistributionManager dist *meta.DistributionManager
meta *meta.Meta meta *meta.Meta
targetMgr *meta.TargetManager targetMgr *meta.TargetManager
targetObserver *TargetObserver
// Test object // Test object
ob *CollectionObserver ob *CollectionObserver
@ -180,18 +181,30 @@ func (suite *CollectionObserverSuite) SetupTest() {
suite.meta = meta.NewMeta(suite.idAllocator, suite.store) suite.meta = meta.NewMeta(suite.idAllocator, suite.store)
suite.broker = meta.NewMockBroker(suite.T()) suite.broker = meta.NewMockBroker(suite.T())
suite.targetMgr = meta.NewTargetManager(suite.broker, suite.meta) suite.targetMgr = meta.NewTargetManager(suite.broker, suite.meta)
suite.targetObserver = NewTargetObserver(suite.meta,
suite.targetMgr,
suite.dist,
suite.broker,
)
// Test object // Test object
suite.ob = NewCollectionObserver( suite.ob = NewCollectionObserver(
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
) )
for _, collection := range suite.collections {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe()
}
suite.targetObserver.Start(context.Background())
suite.loadAll() suite.loadAll()
} }
func (suite *CollectionObserverSuite) TearDownTest() { func (suite *CollectionObserverSuite) TearDownTest() {
suite.targetObserver.Stop()
suite.ob.Stop() suite.ob.Stop()
suite.kv.Close() suite.kv.Close()
} }

View File

@ -30,6 +30,11 @@ import (
"github.com/milvus-io/milvus/internal/util/typeutil" "github.com/milvus-io/milvus/internal/util/typeutil"
) )
type checkRequest struct {
CollectionID int64
Notifier chan bool
}
type targetUpdateRequest struct { type targetUpdateRequest struct {
CollectionID int64 CollectionID int64
Notifier chan error Notifier chan error
@ -44,6 +49,7 @@ type TargetObserver struct {
distMgr *meta.DistributionManager distMgr *meta.DistributionManager
broker meta.Broker broker meta.Broker
manualCheck chan checkRequest
nextTargetLastUpdate map[int64]time.Time nextTargetLastUpdate map[int64]time.Time
updateChan chan targetUpdateRequest updateChan chan targetUpdateRequest
mut sync.Mutex // Guard readyNotifiers mut sync.Mutex // Guard readyNotifiers
@ -59,6 +65,7 @@ func NewTargetObserver(meta *meta.Meta, targetMgr *meta.TargetManager, distMgr *
targetMgr: targetMgr, targetMgr: targetMgr,
distMgr: distMgr, distMgr: distMgr,
broker: broker, broker: broker,
manualCheck: make(chan checkRequest, 10),
nextTargetLastUpdate: make(map[int64]time.Time), nextTargetLastUpdate: make(map[int64]time.Time),
updateChan: make(chan targetUpdateRequest), updateChan: make(chan targetUpdateRequest),
readyNotifiers: make(map[int64][]chan struct{}), readyNotifiers: make(map[int64][]chan struct{}),
@ -95,21 +102,48 @@ func (ob *TargetObserver) schedule(ctx context.Context) {
ob.clean() ob.clean()
ob.tryUpdateTarget() ob.tryUpdateTarget()
case request := <-ob.updateChan: case req := <-ob.manualCheck:
err := ob.updateNextTarget(request.CollectionID) ob.check(req.CollectionID)
req.Notifier <- ob.targetMgr.IsCurrentTargetExist(req.CollectionID)
case req := <-ob.updateChan:
err := ob.updateNextTarget(req.CollectionID)
if err != nil { if err != nil {
close(request.ReadyNotifier) close(req.ReadyNotifier)
} else { } else {
ob.mut.Lock() ob.mut.Lock()
ob.readyNotifiers[request.CollectionID] = append(ob.readyNotifiers[request.CollectionID], request.ReadyNotifier) ob.readyNotifiers[req.CollectionID] = append(ob.readyNotifiers[req.CollectionID], req.ReadyNotifier)
ob.mut.Unlock() ob.mut.Unlock()
} }
request.Notifier <- err req.Notifier <- err
} }
} }
} }
// Check checks whether the next target is ready,
// and updates the current target if it is,
// returns true if current target is not nil
func (ob *TargetObserver) Check(collectionID int64) bool {
notifier := make(chan bool)
ob.manualCheck <- checkRequest{
CollectionID: collectionID,
Notifier: notifier,
}
return <-notifier
}
func (ob *TargetObserver) check(collectionID int64) {
if ob.shouldUpdateCurrentTarget(collectionID) {
ob.updateCurrentTarget(collectionID)
}
if ob.shouldUpdateNextTarget(collectionID) {
// update next target in collection level
ob.updateNextTarget(collectionID)
}
}
// UpdateNextTarget updates the next target, // UpdateNextTarget updates the next target,
// returns a channel which will be closed when the next target is ready, // returns a channel which will be closed when the next target is ready,
// or returns error if failed to pull target // or returns error if failed to pull target
@ -138,14 +172,7 @@ func (ob *TargetObserver) ReleaseCollection(collectionID int64) {
func (ob *TargetObserver) tryUpdateTarget() { func (ob *TargetObserver) tryUpdateTarget() {
collections := ob.meta.GetAll() collections := ob.meta.GetAll()
for _, collectionID := range collections { for _, collectionID := range collections {
if ob.shouldUpdateCurrentTarget(collectionID) { ob.check(collectionID)
ob.updateCurrentTarget(collectionID)
}
if ob.shouldUpdateNextTarget(collectionID) {
// update next target in collection level
ob.updateNextTarget(collectionID)
}
} }
collectionSet := typeutil.NewUniqueSet(collections...) collectionSet := typeutil.NewUniqueSet(collections...)
@ -199,12 +226,6 @@ func (ob *TargetObserver) updateNextTargetTimestamp(collectionID int64) {
} }
func (ob *TargetObserver) shouldUpdateCurrentTarget(collectionID int64) bool { func (ob *TargetObserver) shouldUpdateCurrentTarget(collectionID int64) bool {
// Collection observer will update the current target as loading done,
// avoid double updating, which will cause update current target to a unfinished next target
if !ob.targetMgr.IsCurrentTargetExist(collectionID) {
return false
}
replicaNum := ob.meta.CollectionManager.GetReplicaNumber(collectionID) replicaNum := ob.meta.CollectionManager.GetReplicaNumber(collectionID)
// check channel first // check channel first

View File

@ -279,11 +279,6 @@ func (s *Server) initMeta() error {
func (s *Server) initObserver() { func (s *Server) initObserver() {
log.Info("init observers") log.Info("init observers")
s.collectionObserver = observers.NewCollectionObserver(
s.dist,
s.meta,
s.targetMgr,
)
s.leaderObserver = observers.NewLeaderObserver( s.leaderObserver = observers.NewLeaderObserver(
s.dist, s.dist,
s.meta, s.meta,
@ -296,6 +291,12 @@ func (s *Server) initObserver() {
s.dist, s.dist,
s.broker, s.broker,
) )
s.collectionObserver = observers.NewCollectionObserver(
s.dist,
s.meta,
s.targetMgr,
s.targetObserver,
)
} }
func (s *Server) afterStart() { func (s *Server) afterStart() {

View File

@ -35,6 +35,7 @@ import (
"github.com/milvus-io/milvus/internal/querycoordv2/dist" "github.com/milvus-io/milvus/internal/querycoordv2/dist"
"github.com/milvus-io/milvus/internal/querycoordv2/meta" "github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/mocks" "github.com/milvus-io/milvus/internal/querycoordv2/mocks"
"github.com/milvus-io/milvus/internal/querycoordv2/observers"
"github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/task" "github.com/milvus-io/milvus/internal/querycoordv2/task"
@ -401,6 +402,18 @@ func (suite *ServerSuite) hackServer() {
suite.server.balancer, suite.server.balancer,
suite.server.taskScheduler, suite.server.taskScheduler,
) )
suite.server.targetObserver = observers.NewTargetObserver(
suite.server.meta,
suite.server.targetMgr,
suite.server.dist,
suite.broker,
)
suite.server.collectionObserver = observers.NewCollectionObserver(
suite.server.dist,
suite.server.meta,
suite.server.targetMgr,
suite.server.targetObserver,
)
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything).Return(&schemapb.CollectionSchema{}, nil).Maybe() suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything).Return(&schemapb.CollectionSchema{}, nil).Maybe()
for _, collection := range suite.collections { for _, collection := range suite.collections {