From 3e1052f8892ca5af1dbdf01b6a40fac6a486fc2b Mon Sep 17 00:00:00 2001 From: SimFG Date: Tue, 27 Aug 2024 10:28:59 +0800 Subject: [PATCH] enhance: use the msg position obj when getting replicate channel position (#35606) /kind improvement Signed-off-by: SimFG --- internal/proxy/impl.go | 36 +++++++++++++++++++++++------------- internal/proxy/impl_test.go | 22 +++++++++++++++++++++- 2 files changed, 44 insertions(+), 14 deletions(-) diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index 886049dee3..d44c010563 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -6046,10 +6046,9 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate }, nil } var err error - ctxLog := log.Ctx(ctx) if req.GetChannelName() == "" { - ctxLog.Warn("channel name is empty") + log.Ctx(ctx).Warn("channel name is empty") return &milvuspb.ReplicateMessageResponse{ Status: merr.Status(merr.WrapErrParameterInvalidMsg("invalid channel name for the replicate message request")), }, nil @@ -6060,11 +6059,22 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate if req.GetChannelName() == replicateMsgChannel { msgID, err := msgstream.GetChannelLatestMsgID(ctx, node.factory, replicateMsgChannel) if err != nil { - ctxLog.Warn("failed to get the latest message id of the replicate msg channel", zap.Error(err)) + log.Ctx(ctx).Warn("failed to get the latest message id of the replicate msg channel", zap.Error(err)) return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil } - position := base64.StdEncoding.EncodeToString(msgID) - return &milvuspb.ReplicateMessageResponse{Status: merr.Status(nil), Position: position}, nil + position := &msgpb.MsgPosition{ + ChannelName: replicateMsgChannel, + MsgID: msgID, + } + positionBytes, err := proto.Marshal(position) + if err != nil { + log.Ctx(ctx).Warn("failed to marshal position", zap.Error(err)) + return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil + } + return &milvuspb.ReplicateMessageResponse{ + Status: merr.Status(nil), + Position: base64.StdEncoding.EncodeToString(positionBytes), + }, nil } msgPack := &msgstream.MsgPack{ @@ -6079,16 +6089,16 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate header := commonpb.MsgHeader{} err = proto.Unmarshal(msgBytes, &header) if err != nil { - ctxLog.Warn("failed to unmarshal msg header", zap.Int("index", i), zap.Error(err)) + log.Ctx(ctx).Warn("failed to unmarshal msg header", zap.Int("index", i), zap.Error(err)) return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil } if header.GetBase() == nil { - ctxLog.Warn("msg header base is nil", zap.Int("index", i)) + log.Ctx(ctx).Warn("msg header base is nil", zap.Int("index", i)) return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.ErrInvalidMsgBytes)}, nil } tsMsg, err := node.replicateStreamManager.GetMsgDispatcher().Unmarshal(msgBytes, header.GetBase().GetMsgType()) if err != nil { - ctxLog.Warn("failed to unmarshal msg", zap.Int("index", i), zap.Error(err)) + log.Ctx(ctx).Warn("failed to unmarshal msg", zap.Int("index", i), zap.Error(err)) return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.ErrInvalidMsgBytes)}, nil } switch realMsg := tsMsg.(type) { @@ -6096,11 +6106,11 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate assignedSegmentInfos, err := node.segAssigner.GetSegmentID(realMsg.GetCollectionID(), realMsg.GetPartitionID(), realMsg.GetShardName(), uint32(realMsg.NumRows), req.EndTs) if err != nil { - ctxLog.Warn("failed to get segment id", zap.Error(err)) + log.Ctx(ctx).Warn("failed to get segment id", zap.Error(err)) return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil } if len(assignedSegmentInfos) == 0 { - ctxLog.Warn("no segment id assigned") + log.Ctx(ctx).Warn("no segment id assigned") return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.ErrNoAssignSegmentID)}, nil } for assignSegmentID := range assignedSegmentInfos { @@ -6113,19 +6123,19 @@ func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.Replicate msgStream, err := node.replicateStreamManager.GetReplicateMsgStream(ctx, req.ChannelName) if err != nil { - ctxLog.Warn("failed to get msg stream from the replicate stream manager", zap.Error(err)) + log.Ctx(ctx).Warn("failed to get msg stream from the replicate stream manager", zap.Error(err)) return &milvuspb.ReplicateMessageResponse{ Status: merr.Status(err), }, nil } messageIDsMap, err := msgStream.Broadcast(msgPack) if err != nil { - ctxLog.Warn("failed to produce msg", zap.Error(err)) + log.Ctx(ctx).Warn("failed to produce msg", zap.Error(err)) return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil } var position string if len(messageIDsMap[req.GetChannelName()]) == 0 { - ctxLog.Warn("no message id returned") + log.Ctx(ctx).Warn("no message id returned") } else { messageIDs := messageIDsMap[req.GetChannelName()] position = base64.StdEncoding.EncodeToString(messageIDs[len(messageIDs)-1].Serialize()) diff --git a/internal/proxy/impl_test.go b/internal/proxy/impl_test.go index 6104b99be5..70171fe3b8 100644 --- a/internal/proxy/impl_test.go +++ b/internal/proxy/impl_test.go @@ -29,6 +29,7 @@ import ( "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" @@ -1376,6 +1377,21 @@ func TestProxy_ReplicateMessage(t *testing.T) { }) t.Run("get latest position", func(t *testing.T) { + base64DecodeMsgPosition := func(position string) (*msgstream.MsgPosition, error) { + decodeBytes, err := base64.StdEncoding.DecodeString(position) + if err != nil { + log.Warn("fail to decode the position", zap.Error(err)) + return nil, err + } + msgPosition := &msgstream.MsgPosition{} + err = proto.Unmarshal(decodeBytes, msgPosition) + if err != nil { + log.Warn("fail to unmarshal the position", zap.Error(err)) + return nil, err + } + return msgPosition, nil + } + paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false") defer paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true") @@ -1397,7 +1413,11 @@ func TestProxy_ReplicateMessage(t *testing.T) { }) assert.NoError(t, err) assert.EqualValues(t, 0, resp.GetStatus().GetCode()) - assert.Equal(t, base64.StdEncoding.EncodeToString([]byte("mock")), resp.GetPosition()) + { + p, err := base64DecodeMsgPosition(resp.GetPosition()) + assert.NoError(t, err) + assert.Equal(t, []byte("mock"), p.MsgID) + } factory.EXPECT().NewMsgStream(mock.Anything).Return(nil, errors.New("mock")).Once() resp, err = node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{