mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
enhance: keep consistent of memory and meta of broadcaster (#39721)
issue: #38399 pr: #39720 Signed-off-by: chyezh <chyezh@outlook.com>
This commit is contained in:
parent
4a8e6fc59c
commit
b936061ec6
@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
@ -18,15 +19,10 @@ import (
|
||||
func newBroadcastTaskFromProto(proto *streamingpb.BroadcastTask) *broadcastTask {
|
||||
msg := message.NewBroadcastMutableMessageBeforeAppend(proto.Message.Payload, proto.Message.Properties)
|
||||
bh := msg.BroadcastHeader()
|
||||
ackedCount := 0
|
||||
for _, acked := range proto.AckedVchannelBitmap {
|
||||
ackedCount += int(acked)
|
||||
}
|
||||
return &broadcastTask{
|
||||
mu: sync.Mutex{},
|
||||
header: bh,
|
||||
task: proto,
|
||||
ackedCount: ackedCount,
|
||||
recoverPersisted: true, // the task is recovered from the recovery info, so it's persisted.
|
||||
}
|
||||
}
|
||||
@ -43,7 +39,6 @@ func newBroadcastTaskFromBroadcastMessage(msg message.BroadcastMutableMessage) *
|
||||
State: streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING,
|
||||
AckedVchannelBitmap: make([]byte, len(header.VChannels)),
|
||||
},
|
||||
ackedCount: 0,
|
||||
recoverPersisted: false,
|
||||
}
|
||||
}
|
||||
@ -51,11 +46,9 @@ func newBroadcastTaskFromBroadcastMessage(msg message.BroadcastMutableMessage) *
|
||||
// broadcastTask is the state of the broadcast task.
|
||||
type broadcastTask struct {
|
||||
log.Binder
|
||||
mu sync.Mutex
|
||||
header *message.BroadcastHeader
|
||||
task *streamingpb.BroadcastTask
|
||||
ackedCount int // the count of the acked vchannels, the idempotence is promised by task's bitmap.
|
||||
// always keep same with the positive counter of task's acked_bitmap.
|
||||
mu sync.Mutex
|
||||
header *message.BroadcastHeader
|
||||
task *streamingpb.BroadcastTask
|
||||
recoverPersisted bool // a flag to indicate that the task has been persisted into the recovery info and can be recovered.
|
||||
}
|
||||
|
||||
@ -80,10 +73,6 @@ func (b *broadcastTask) PendingBroadcastMessages() []message.MutableMessage {
|
||||
|
||||
msg := message.NewBroadcastMutableMessageBeforeAppend(b.task.Message.Payload, b.task.Message.Properties)
|
||||
msgs := msg.SplitIntoMutableMessage()
|
||||
// If there's no vchannel acked, return all the messages directly.
|
||||
if b.ackedCount == 0 {
|
||||
return msgs
|
||||
}
|
||||
// filter out the vchannel that has been acked.
|
||||
pendingMessages := make([]message.MutableMessage, 0, len(msgs))
|
||||
for i, msg := range msgs {
|
||||
@ -103,7 +92,7 @@ func (b *broadcastTask) InitializeRecovery(ctx context.Context) error {
|
||||
if b.recoverPersisted {
|
||||
return nil
|
||||
}
|
||||
if err := b.saveTask(ctx, b.Logger()); err != nil {
|
||||
if err := b.saveTask(ctx, b.task, b.Logger()); err != nil {
|
||||
return err
|
||||
}
|
||||
b.recoverPersisted = true
|
||||
@ -115,36 +104,44 @@ func (b *broadcastTask) Ack(ctx context.Context, vchannel string) error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
b.setVChannelAcked(vchannel)
|
||||
if b.isAllDone() {
|
||||
// All vchannels are acked, so mark the task as done even if some pending messages are still being sent.
|
||||
// Those pending messages are repeated send operations and can be ignored.
|
||||
b.task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE
|
||||
task, ok := b.copyAndSetVChannelAcked(vchannel)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// We should always save the task after acked.
|
||||
// Even if the task is marked as done in memory.
|
||||
// Because the task is set as done in memory before saving the recovery info.
|
||||
return b.saveTask(ctx, b.Logger().With(zap.String("ackVChannel", vchannel)))
|
||||
if err := b.saveTask(ctx, task, b.Logger().With(zap.String("ackVChannel", vchannel))); err != nil {
|
||||
return err
|
||||
}
|
||||
b.task = task
|
||||
return nil
|
||||
}
|
||||
|
||||
// setVChannelAcked sets the vchannel as acked.
|
||||
func (b *broadcastTask) setVChannelAcked(vchannel string) {
|
||||
idx, err := b.findIdxOfVChannel(vchannel)
|
||||
// copyAndSetVChannelAcked copies the task and sets the vchannel as acked.
|
||||
// if the vchannel is already acked, it returns nil and false.
|
||||
func (b *broadcastTask) copyAndSetVChannelAcked(vchannel string) (*streamingpb.BroadcastTask, bool) {
|
||||
task := proto.Clone(b.task).(*streamingpb.BroadcastTask)
|
||||
idx, err := findIdxOfVChannel(vchannel, b.Header().VChannels)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
b.task.AckedVchannelBitmap[idx] = 1
|
||||
// Check if all vchannels are acked.
|
||||
ackedCount := 0
|
||||
for _, acked := range b.task.AckedVchannelBitmap {
|
||||
ackedCount += int(acked)
|
||||
if task.AckedVchannelBitmap[idx] != 0 {
|
||||
return nil, false
|
||||
}
|
||||
b.ackedCount = ackedCount
|
||||
task.AckedVchannelBitmap[idx] = 1
|
||||
if isAllDone(task) {
|
||||
// All vchannels are acked, so mark the task as done even if some pending messages are still being sent.
|
||||
// Those pending messages are repeated send operations and can be ignored.
|
||||
task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE
|
||||
}
|
||||
return task, true
|
||||
}
|
||||
|
||||
// findIdxOfVChannel finds the index of the vchannel in the broadcast task.
|
||||
func (b *broadcastTask) findIdxOfVChannel(vchannel string) (int, error) {
|
||||
for i, channelName := range b.header.VChannels {
|
||||
func findIdxOfVChannel(vchannel string, vchannels []string) (int, error) {
|
||||
for i, channelName := range vchannels {
|
||||
if channelName == vchannel {
|
||||
return i, nil
|
||||
}
|
||||
@ -152,44 +149,74 @@ func (b *broadcastTask) findIdxOfVChannel(vchannel string) (int, error) {
|
||||
return -1, errors.Errorf("unreachable: vchannel is %s not found in the broadcast task", vchannel)
|
||||
}
|
||||
|
||||
// isAllDone checks whether all the vchannels are acked.
|
||||
func (b *broadcastTask) isAllDone() bool {
|
||||
return b.ackedCount == len(b.header.VChannels)
|
||||
}
|
||||
|
||||
// BroadcastDone marks the broadcast operation as done.
|
||||
func (b *broadcastTask) BroadcastDone(ctx context.Context) error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.isAllDone() {
|
||||
task := b.copyAndMarkBroadcastDone()
|
||||
if err := b.saveTask(ctx, task, b.Logger()); err != nil {
|
||||
return err
|
||||
}
|
||||
b.task = task
|
||||
return nil
|
||||
}
|
||||
|
||||
// copyAndMarkBroadcastDone copies the task and marks the broadcast as finished: DONE if all vchannels are acked, WAIT_ACK otherwise.
|
||||
func (b *broadcastTask) copyAndMarkBroadcastDone() *streamingpb.BroadcastTask {
|
||||
task := proto.Clone(b.task).(*streamingpb.BroadcastTask)
|
||||
if isAllDone(task) {
|
||||
// If all vchannels are acked, mark the task as done.
|
||||
b.task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE
|
||||
task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE
|
||||
} else {
|
||||
// There's no more pending message, mark the task as wait ack.
|
||||
b.task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_WAIT_ACK
|
||||
task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_WAIT_ACK
|
||||
}
|
||||
return b.saveTask(ctx, b.Logger())
|
||||
return task
|
||||
}
|
||||
|
||||
// IsAllAcked returns true if all the vchannels are acked.
|
||||
func (b *broadcastTask) IsAllAcked() bool {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.isAllDone()
|
||||
return isAllDone(b.task)
|
||||
}
|
||||
|
||||
// isAllDone checks whether all the vchannels are acked.
|
||||
func isAllDone(task *streamingpb.BroadcastTask) bool {
|
||||
for _, acked := range task.AckedVchannelBitmap {
|
||||
if acked == 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// ackedCount returns the count of the acked vchannels.
|
||||
func ackedCount(task *streamingpb.BroadcastTask) int {
|
||||
count := 0
|
||||
for _, acked := range task.AckedVchannelBitmap {
|
||||
count += int(acked)
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// IsAcked returns true if any vchannel is acked.
|
||||
func (b *broadcastTask) IsAcked() bool {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.ackedCount > 0
|
||||
for _, acked := range b.task.AckedVchannelBitmap {
|
||||
if acked != 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// saveTask saves the broadcast task recovery info.
|
||||
func (b *broadcastTask) saveTask(ctx context.Context, logger *log.MLogger) error {
|
||||
logger = logger.With(zap.String("state", b.task.State.String()), zap.Int("ackedVChannelCount", b.ackedCount))
|
||||
if err := resource.Resource().StreamingCatalog().SaveBroadcastTask(ctx, b.header.BroadcastID, b.task); err != nil {
|
||||
func (b *broadcastTask) saveTask(ctx context.Context, task *streamingpb.BroadcastTask, logger *log.MLogger) error {
|
||||
logger = logger.With(zap.String("state", task.State.String()), zap.Int("ackedVChannelCount", ackedCount(task)))
|
||||
if err := resource.Resource().StreamingCatalog().SaveBroadcastTask(ctx, b.header.BroadcastID, task); err != nil {
|
||||
logger.Warn("save broadcast task failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user