From de8f0af20d3ed283f69590c7f87c81798811cde4 Mon Sep 17 00:00:00 2001
From: Zhen Ye
Date: Tue, 6 May 2025 01:12:53 +0800
Subject: [PATCH] enhance: use dispatcher at delegator when streaming is
 enabled (#41266)

issue: #38399

- add an adaptor type that adapts the streaming service client to the
  msgstream client interface, so that the msgdispatcher can be reused.

Signed-off-by: chyezh
---
 .../streaming/msgstream_adaptor.go            | 132 +++++++++++++
 .../streaming/msgstream_adaptor_test.go       | 174 ++++++++++++++++++
 internal/distributed/streaming/streaming.go   |   5 +
 internal/distributed/streaming/wal.go         |  15 +-
 internal/distributed/streaming/wal_test.go    |  24 ++-
 .../querynodev2/delegator/delegator_data.go   |  30 +--
 internal/querynodev2/server.go                |   8 +-
 .../flusherimpl/data_service_wrapper.go       |   4 +-
 internal/util/pipeline/stream_pipeline.go     |  29 ---
 pkg/mq/msgstream/msgstream_util.go            |  17 ++
 pkg/streaming/util/message/adaptor/handler.go |  16 +-
 tests/go_client/testcases/collection_test.go  |   4 +-
 .../testcases/test_partition_key_isolation.py |   9 +-
 .../testcases/test_phrase_match.py            |  11 +-
 14 files changed, 393 insertions(+), 85 deletions(-)
 create mode 100644 internal/distributed/streaming/msgstream_adaptor.go
 create mode 100644 internal/distributed/streaming/msgstream_adaptor_test.go

diff --git a/internal/distributed/streaming/msgstream_adaptor.go b/internal/distributed/streaming/msgstream_adaptor.go
new file mode 100644
index 0000000000..638551f163
--- /dev/null
+++ b/internal/distributed/streaming/msgstream_adaptor.go
@@ -0,0 +1,132 @@
+package streaming
+
+import (
+	"context"
+
+	"go.uber.org/zap"
+
+	"github.com/milvus-io/milvus/pkg/v2/log"
+	"github.com/milvus-io/milvus/pkg/v2/mq/common"
+	"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
+	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
+	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message/adaptor"
+	"github.com/milvus-io/milvus/pkg/v2/streaming/util/options"
+	"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
+)
+
+var (
+	_ msgstream.Factory   = (*delegatorMsgstreamFactory)(nil)
+	_ msgstream.MsgStream = (*delegatorMsgstreamAdaptor)(nil)
+)
+
+// NewDelegatorMsgstreamFactory returns a streaming-based msgstream factory for the delegator.
+func NewDelegatorMsgstreamFactory() msgstream.Factory {
+	return &delegatorMsgstreamFactory{}
+}
+
+// delegatorMsgstreamFactory is only used by the delegator.
+type delegatorMsgstreamFactory struct{}
+
+func (f *delegatorMsgstreamFactory) NewMsgStream(ctx context.Context) (msgstream.MsgStream, error) {
+	panic("should never be called")
+}
+
+func (f *delegatorMsgstreamFactory) NewTtMsgStream(ctx context.Context) (msgstream.MsgStream, error) {
+	return &delegatorMsgstreamAdaptor{}, nil
+}
+
+func (f *delegatorMsgstreamFactory) NewMsgStreamDisposer(ctx context.Context) func([]string, string) error {
+	panic("should never be called")
+}
+
+// delegatorMsgstreamAdaptor is only used by the delegator.
+type delegatorMsgstreamAdaptor struct { + scanner Scanner + ch <-chan *msgstream.ConsumeMsgPack +} + +func (m *delegatorMsgstreamAdaptor) Close() { + if m.scanner != nil { + m.scanner.Close() + } +} + +func (m *delegatorMsgstreamAdaptor) AsProducer(ctx context.Context, channels []string) { + panic("should never be called") +} + +func (m *delegatorMsgstreamAdaptor) Produce(context.Context, *msgstream.MsgPack) error { + panic("should never be called") +} + +func (m *delegatorMsgstreamAdaptor) SetRepackFunc(repackFunc msgstream.RepackFunc) { + panic("should never be called") +} + +func (m *delegatorMsgstreamAdaptor) GetProduceChannels() []string { + panic("should never be called") +} + +func (m *delegatorMsgstreamAdaptor) Broadcast(context.Context, *msgstream.MsgPack) (map[string][]msgstream.MessageID, error) { + panic("should never be called") +} + +func (m *delegatorMsgstreamAdaptor) AsConsumer(ctx context.Context, channels []string, subName string, position common.SubscriptionInitialPosition) error { + // always ignored. + if position != common.SubscriptionPositionUnknown { + panic("should never be called") + } + return nil +} + +func (m *delegatorMsgstreamAdaptor) Chan() <-chan *msgstream.ConsumeMsgPack { + if m.ch == nil { + panic("should never be called if seek is not done") + } + return m.ch +} + +func (m *delegatorMsgstreamAdaptor) GetUnmarshalDispatcher() msgstream.UnmarshalDispatcher { + return adaptor.UnmashalerDispatcher +} + +func (m *delegatorMsgstreamAdaptor) Seek(ctx context.Context, msgPositions []*msgstream.MsgPosition, includeCurrentMsg bool) error { + if len(msgPositions) != 1 { + panic("should never be called if len(msgPositions) is not 1") + } + position := msgPositions[0] + startFrom := adaptor.MustGetMessageIDFromMQWrapperIDBytes(WAL().WALName(), position.MsgID) + log.Info( + "delegator msgstream adaptor seeks from position with scanner", + zap.String("channel", position.GetChannelName()), + zap.Any("startFromMessageID", startFrom), + zap.Uint64("timestamp", position.GetTimestamp()), + ) + handler := adaptor.NewMsgPackAdaptorHandler() + pchannel := funcutil.ToPhysicalChannel(position.GetChannelName()) + m.scanner = WAL().Read(ctx, ReadOption{ + PChannel: pchannel, + DeliverPolicy: options.DeliverPolicyStartFrom(startFrom), + DeliverFilters: []options.DeliverFilter{ + // only consume messages with timestamp >= position timestamp + options.DeliverFilterTimeTickGTE(position.GetTimestamp()), + // only consume insert and delete messages + options.DeliverFilterMessageType(message.MessageTypeInsert, message.MessageTypeDelete, message.MessageTypeSchemaChange), + }, + MessageHandler: handler, + }) + m.ch = handler.Chan() + return nil +} + +func (m *delegatorMsgstreamAdaptor) GetLatestMsgID(channel string) (msgstream.MessageID, error) { + panic("should never be called") +} + +func (m *delegatorMsgstreamAdaptor) CheckTopicValid(channel string) error { + panic("should never be called") +} + +func (m *delegatorMsgstreamAdaptor) ForceEnableProduce(can bool) { + panic("should never be called") +} diff --git a/internal/distributed/streaming/msgstream_adaptor_test.go b/internal/distributed/streaming/msgstream_adaptor_test.go new file mode 100644 index 0000000000..c07ecaf195 --- /dev/null +++ b/internal/distributed/streaming/msgstream_adaptor_test.go @@ -0,0 +1,174 @@ +package streaming + +import ( + "context" + "testing" + + "github.com/milvus-io/milvus/pkg/v2/mq/common" + "github.com/milvus-io/milvus/pkg/v2/mq/msgstream" +) + +func TestDelegatorMsgstreamFactory(t *testing.T) { + 
factory := NewDelegatorMsgstreamFactory()
+
+	// Test NewMsgStream
+	t.Run("NewMsgStream", func(t *testing.T) {
+		defer func() {
+			if r := recover(); r == nil {
+				t.Errorf("NewMsgStream should panic but did not")
+			}
+		}()
+		_, _ = factory.NewMsgStream(context.Background())
+	})
+
+	// Test NewTtMsgStream
+	t.Run("NewTtMsgStream", func(t *testing.T) {
+		stream, err := factory.NewTtMsgStream(context.Background())
+		if err != nil {
+			t.Errorf("NewTtMsgStream returned an error: %v", err)
+		}
+		if stream == nil {
+			t.Errorf("NewTtMsgStream returned nil stream")
+		}
+	})
+
+	// Test NewMsgStreamDisposer
+	t.Run("NewMsgStreamDisposer", func(t *testing.T) {
+		defer func() {
+			if r := recover(); r == nil {
+				t.Errorf("NewMsgStreamDisposer should panic but did not")
+			}
+		}()
+		_ = factory.NewMsgStreamDisposer(context.Background())
+	})
+}
+
+func TestDelegatorMsgstreamAdaptor(t *testing.T) {
+	adaptor := &delegatorMsgstreamAdaptor{}
+
+	// Test Close
+	t.Run("Close", func(t *testing.T) {
+		defer func() {
+			if r := recover(); r != nil {
+				t.Errorf("Close should not panic but did")
+			}
+		}()
+		adaptor.Close()
+	})
+
+	// Test AsProducer
+	t.Run("AsProducer", func(t *testing.T) {
+		defer func() {
+			if r := recover(); r == nil {
+				t.Errorf("AsProducer should panic but did not")
+			}
+		}()
+		adaptor.AsProducer(context.Background(), []string{"channel1"})
+	})
+
+	// Test Produce
+	t.Run("Produce", func(t *testing.T) {
+		defer func() {
+			if r := recover(); r == nil {
+				t.Errorf("Produce should panic but did not")
+			}
+		}()
+		_ = adaptor.Produce(context.Background(), &msgstream.MsgPack{})
+	})
+
+	// Test SetRepackFunc
+	t.Run("SetRepackFunc", func(t *testing.T) {
+		defer func() {
+			if r := recover(); r == nil {
+				t.Errorf("SetRepackFunc should panic but did not")
+			}
+		}()
+		adaptor.SetRepackFunc(nil)
+	})
+
+	// Test GetProduceChannels
+	t.Run("GetProduceChannels", func(t *testing.T) {
+		defer func() {
+			if r := recover(); r == nil {
+				t.Errorf("GetProduceChannels should panic but did not")
+			}
+		}()
+		_ = adaptor.GetProduceChannels()
+	})
+
+	// Test Broadcast
+	t.Run("Broadcast", func(t *testing.T) {
+		defer func() {
+			if r := recover(); r == nil {
+				t.Errorf("Broadcast should panic but did not")
+			}
+		}()
+		_, _ = adaptor.Broadcast(context.Background(), &msgstream.MsgPack{})
+	})
+
+	// Test AsConsumer
+	t.Run("AsConsumer", func(t *testing.T) {
+		err := adaptor.AsConsumer(context.Background(), []string{"channel1"}, "subName", common.SubscriptionPositionUnknown)
+		if err != nil {
+			t.Errorf("AsConsumer returned an error: %v", err)
+		}
+	})
+
+	// Test Chan
+	t.Run("Chan", func(t *testing.T) {
+		defer func() {
+			if r := recover(); r == nil {
+				t.Errorf("Chan should panic if Seek has not been called but did not")
+			}
+		}()
+		adaptor.Chan()
+	})
+
+	// Test GetUnmarshalDispatcher
+	t.Run("GetUnmarshalDispatcher", func(t *testing.T) {
+		dispatcher := adaptor.GetUnmarshalDispatcher()
+		if dispatcher == nil {
+			t.Errorf("GetUnmarshalDispatcher returned nil")
+		}
+	})
+
+	// Test Seek
+	t.Run("Seek", func(t *testing.T) {
+		defer func() {
+			if r := recover(); r == nil {
+				t.Errorf("Seek should panic if len(msgPositions) != 1 but did not")
+			}
+		}()
+		_ = adaptor.Seek(context.Background(), []*msgstream.MsgPosition{}, true)
+	})
+
+	// Test GetLatestMsgID
+	t.Run("GetLatestMsgID", func(t *testing.T) {
+		defer func() {
+			if r := recover(); r == nil {
+				t.Errorf("GetLatestMsgID should panic but did not")
+			}
+		}()
+		_, _ = adaptor.GetLatestMsgID("channel1")
+	})
+
+	// Test CheckTopicValid
+	t.Run("CheckTopicValid", func(t *testing.T) {
+		defer func() {
+			if r := recover(); r == nil {
+				t.Errorf("CheckTopicValid should panic but did not")
+			}
+		}()
+		_ = adaptor.CheckTopicValid("channel1")
+	})
+
+	// Test ForceEnableProduce
+	t.Run("ForceEnableProduce", func(t *testing.T) {
+		defer func() {
+			if r := recover(); r == nil {
+				t.Errorf("ForceEnableProduce should panic but did not")
+			}
+		}()
+		adaptor.ForceEnableProduce(true)
+	})
+}
diff --git a/internal/distributed/streaming/streaming.go b/internal/distributed/streaming/streaming.go
index a22f563c06..3f9676bf78 100644
--- a/internal/distributed/streaming/streaming.go
+++ b/internal/distributed/streaming/streaming.go
@@ -53,8 +53,13 @@ type TxnOption struct {
 }
 
 type ReadOption struct {
+	// PChannel is the target pchannel to read. If it is not set,
+	// it will be parsed from the given `VChannel`.
+	PChannel string
+
 	// VChannel is the target vchannel to read.
 	// It must be set to read message from a vchannel.
+	// If VChannel is empty, PChannel must be set, and all messages of the pchannel will be read.
 	VChannel string
 
 	// DeliverPolicy is the deliver policy of the consumer.
diff --git a/internal/distributed/streaming/wal.go b/internal/distributed/streaming/wal.go
index a29103adad..5fcc495384 100644
--- a/internal/distributed/streaming/wal.go
+++ b/internal/distributed/streaming/wal.go
@@ -92,14 +92,20 @@ func (w *walAccesserImpl) Read(_ context.Context, opts ReadOption) Scanner {
 	}
 	defer w.lifetime.Done()
 
-	if opts.VChannel == "" {
-		return newErrScanner(status.NewInvaildArgument("vchannel is required"))
+	if opts.VChannel == "" && opts.PChannel == "" {
+		panic("pchannel is required if vchannel is not set")
 	}
+	if opts.VChannel != "" {
+		pchannel := funcutil.ToPhysicalChannel(opts.VChannel)
+		if opts.PChannel != "" && opts.PChannel != pchannel {
+			panic("pchannel does not match vchannel")
+		}
+		opts.PChannel = pchannel
+	}
 
 	// TODO: optimize the consumer into pchannel level.
-	pchannel := funcutil.ToPhysicalChannel(opts.VChannel)
 	rc := consumer.NewResumableConsumer(w.handlerClient.CreateConsumer, &consumer.ConsumerOptions{
-		PChannel:       pchannel,
+		PChannel:       opts.PChannel,
 		VChannel:       opts.VChannel,
 		DeliverPolicy:  opts.DeliverPolicy,
 		DeliverFilters: opts.DeliverFilters,
@@ -176,6 +182,7 @@ func (w *walAccesserImpl) Close() {
 
 // newErrScanner creates a scanner that returns an error.
func newErrScanner(err error) Scanner { ch := make(chan struct{}) + close(ch) return errScanner{ closedCh: ch, err: err, diff --git a/internal/distributed/streaming/wal_test.go b/internal/distributed/streaming/wal_test.go index cccb0967c8..97bc0b1fc6 100644 --- a/internal/distributed/streaming/wal_test.go +++ b/internal/distributed/streaming/wal_test.go @@ -13,6 +13,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/internal/distributed/streaming/internal/producer" "github.com/milvus-io/milvus/internal/mocks/streamingcoord/mock_client" + "github.com/milvus-io/milvus/internal/mocks/streamingnode/client/handler/mock_consumer" "github.com/milvus-io/milvus/internal/mocks/streamingnode/client/handler/mock_producer" "github.com/milvus-io/milvus/internal/mocks/streamingnode/client/mock_handler" "github.com/milvus-io/milvus/internal/util/streamingutil/status" @@ -29,9 +30,14 @@ const ( vChannel3 = "by-dev-rootcoord-dml_3" ) -func TestWAL(t *testing.T) { +func createMockWAL(t *testing.T) ( + *walAccesserImpl, + *mock_client.MockClient, + *mock_client.MockBroadcastService, + *mock_handler.MockHandlerClient, +) { coordClient := mock_client.NewMockClient(t) - coordClient.EXPECT().Close().Return() + coordClient.EXPECT().Close().Return().Maybe() broadcastServce := mock_client.NewMockBroadcastService(t) broadcastServce.EXPECT().Broadcast(mock.Anything, mock.Anything).RunAndReturn( func(ctx context.Context, bmm message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) { @@ -46,11 +52,13 @@ func TestWAL(t *testing.T) { return &types.BroadcastAppendResult{ AppendResults: result, }, nil - }) - broadcastServce.EXPECT().Ack(mock.Anything, mock.Anything).Return(nil) - coordClient.EXPECT().Broadcast().Return(broadcastServce) + }).Maybe() + broadcastServce.EXPECT().Ack(mock.Anything, mock.Anything).Return(nil).Maybe() + coordClient.EXPECT().Broadcast().Return(broadcastServce).Maybe() handler := mock_handler.NewMockHandlerClient(t) - handler.EXPECT().Close().Return() + c := mock_consumer.NewMockConsumer(t) + handler.EXPECT().CreateConsumer(mock.Anything, mock.Anything).Return(c, nil).Maybe() + handler.EXPECT().Close().Return().Maybe() w := &walAccesserImpl{ lifetime: typeutil.NewLifetime(), @@ -61,8 +69,12 @@ func TestWAL(t *testing.T) { appendExecutionPool: conc.NewPool[struct{}](10), dispatchExecutionPool: conc.NewPool[struct{}](10), } + return w, coordClient, broadcastServce, handler +} +func TestWAL(t *testing.T) { ctx := context.Background() + w, _, _, handler := createMockWAL(t) available := make(chan struct{}) p := mock_producer.NewMockProducer(t) diff --git a/internal/querynodev2/delegator/delegator_data.go b/internal/querynodev2/delegator/delegator_data.go index b0e75284af..aa8cb7353e 100644 --- a/internal/querynodev2/delegator/delegator_data.go +++ b/internal/querynodev2/delegator/delegator_data.go @@ -32,13 +32,11 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/distributed/streaming" "github.com/milvus-io/milvus/internal/querynodev2/cluster" "github.com/milvus-io/milvus/internal/querynodev2/delegator/deletebuffer" "github.com/milvus-io/milvus/internal/querynodev2/pkoracle" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/storage" - "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/v2/common" 
"github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/metrics" @@ -48,9 +46,6 @@ import ( "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" "github.com/milvus-io/milvus/pkg/v2/proto/querypb" "github.com/milvus-io/milvus/pkg/v2/proto/segcorepb" - "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" - "github.com/milvus-io/milvus/pkg/v2/streaming/util/message/adaptor" - "github.com/milvus-io/milvus/pkg/v2/streaming/util/options" "github.com/milvus-io/milvus/pkg/v2/util/commonpbutil" "github.com/milvus-io/milvus/pkg/v2/util/conc" "github.com/milvus-io/milvus/pkg/v2/util/funcutil" @@ -784,24 +779,7 @@ func (sd *shardDelegator) createStreamFromMsgStream(ctx context.Context, positio return dispatcher.Chan(), dispatcher.Close, nil } -func (sd *shardDelegator) createDeleteStreamFromStreamingService(ctx context.Context, position *msgpb.MsgPosition) (ch <-chan *msgstream.MsgPack, closer func(), err error) { - handler := adaptor.NewMsgPackAdaptorHandler() - s := streaming.WAL().Read(ctx, streaming.ReadOption{ - VChannel: position.GetChannelName(), - DeliverPolicy: options.DeliverPolicyStartFrom( - adaptor.MustGetMessageIDFromMQWrapperIDBytes(streaming.WAL().WALName(), position.GetMsgID()), - ), - DeliverFilters: []options.DeliverFilter{ - // only deliver message which timestamp >= position.Timestamp - options.DeliverFilterTimeTickGTE(position.GetTimestamp()), - // only delete message - options.DeliverFilterMessageType(message.MessageTypeDelete), - }, - MessageHandler: handler, - }) - return handler.Chan(), s.Close, nil -} - +// Only used in test. func (sd *shardDelegator) readDeleteFromMsgstream(ctx context.Context, position *msgpb.MsgPosition, safeTs uint64, candidate *pkoracle.BloomFilterSet) (*storage.DeleteData, error) { log := sd.getLogger(ctx).With( zap.String("channel", position.ChannelName), @@ -812,11 +790,7 @@ func (sd *shardDelegator) readDeleteFromMsgstream(ctx context.Context, position var ch <-chan *msgstream.MsgPack var closer func() var err error - if streamingutil.IsStreamingServiceEnabled() { - ch, closer, err = sd.createDeleteStreamFromStreamingService(ctx, position) - } else { - ch, closer, err = sd.createStreamFromMsgStream(ctx, position) - } + ch, closer, err = sd.createStreamFromMsgStream(ctx, position) if closer != nil { defer closer() } diff --git a/internal/querynodev2/server.go b/internal/querynodev2/server.go index 455a03f502..265d9ecdfb 100644 --- a/internal/querynodev2/server.go +++ b/internal/querynodev2/server.go @@ -49,6 +49,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" grpcquerynodeclient "github.com/milvus-io/milvus/internal/distributed/querynode/client" + "github.com/milvus-io/milvus/internal/distributed/streaming" "github.com/milvus-io/milvus/internal/querynodev2/cluster" "github.com/milvus-io/milvus/internal/querynodev2/delegator" "github.com/milvus-io/milvus/internal/querynodev2/pipeline" @@ -62,6 +63,7 @@ import ( "github.com/milvus-io/milvus/internal/util/searchutil/scheduler" "github.com/milvus-io/milvus/internal/util/segcore" "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/config" "github.com/milvus-io/milvus/pkg/v2/log" @@ -392,7 +394,11 @@ func (node *QueryNode) Init() error { node.manager = segments.NewManager() node.loader = segments.NewLoader(node.ctx, node.manager, node.chunkManager) 
node.manager.SetLoader(node.loader) - node.dispClient = msgdispatcher.NewClient(node.factory, typeutil.QueryNodeRole, node.GetNodeID()) + if streamingutil.IsStreamingServiceEnabled() { + node.dispClient = msgdispatcher.NewClient(streaming.NewDelegatorMsgstreamFactory(), typeutil.QueryNodeRole, node.GetNodeID()) + } else { + node.dispClient = msgdispatcher.NewClient(node.factory, typeutil.QueryNodeRole, node.GetNodeID()) + } // init pipeline manager node.pipelineManager = pipeline.NewManager(node.manager, node.dispClient, node.delegators) diff --git a/internal/streamingnode/server/flusher/flusherimpl/data_service_wrapper.go b/internal/streamingnode/server/flusher/flusherimpl/data_service_wrapper.go index 6c17b1d211..8af34a14f3 100644 --- a/internal/streamingnode/server/flusher/flusherimpl/data_service_wrapper.go +++ b/internal/streamingnode/server/flusher/flusherimpl/data_service_wrapper.go @@ -42,10 +42,12 @@ func (ds *dataSyncServiceWrapper) Start() { func (ds *dataSyncServiceWrapper) HandleMessage(ctx context.Context, msg message.ImmutableMessage) error { ds.handler.GenerateMsgPack(msg) for ds.handler.PendingMsgPack.Len() > 0 { + next := ds.handler.PendingMsgPack.Next() + nextTsMsg := msgstream.MustBuildMsgPackFromConsumeMsgPack(next, adaptor.UnmashalerDispatcher) select { case <-ctx.Done(): return ctx.Err() - case ds.input <- ds.handler.PendingMsgPack.Next(): + case ds.input <- nextTsMsg: // The input channel will never get stuck because the data sync service will consume the message continuously. ds.handler.PendingMsgPack.UnsafeAdvance() } diff --git a/internal/util/pipeline/stream_pipeline.go b/internal/util/pipeline/stream_pipeline.go index a9800a3941..33885deb78 100644 --- a/internal/util/pipeline/stream_pipeline.go +++ b/internal/util/pipeline/stream_pipeline.go @@ -27,14 +27,10 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/internal/distributed/streaming" - "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/mq/common" "github.com/milvus-io/milvus/pkg/v2/mq/msgdispatcher" "github.com/milvus-io/milvus/pkg/v2/mq/msgstream" - "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" - "github.com/milvus-io/milvus/pkg/v2/streaming/util/message/adaptor" - "github.com/milvus-io/milvus/pkg/v2/streaming/util/options" "github.com/milvus-io/milvus/pkg/v2/util/tsoutil" ) @@ -98,31 +94,6 @@ func (p *streamPipeline) ConsumeMsgStream(ctx context.Context, position *msgpb.M return ErrNilPosition } - if streamingutil.IsStreamingServiceEnabled() { - startFrom := adaptor.MustGetMessageIDFromMQWrapperIDBytes(streaming.WAL().WALName(), position.GetMsgID()) - log.Info( - "stream pipeline seeks from position with scanner", - zap.String("channel", position.GetChannelName()), - zap.Any("startFromMessageID", startFrom), - zap.Uint64("timestamp", position.GetTimestamp()), - ) - handler := adaptor.NewMsgPackAdaptorHandler() - p.scanner = streaming.WAL().Read(ctx, streaming.ReadOption{ - VChannel: position.GetChannelName(), - DeliverPolicy: options.DeliverPolicyStartFrom(startFrom), - DeliverFilters: []options.DeliverFilter{ - // only consume messages with timestamp >= position timestamp - options.DeliverFilterTimeTickGTE(position.GetTimestamp()), - // only consume insert and delete messages - // also schema change message to notify schema change events - options.DeliverFilterMessageType(message.MessageTypeInsert, message.MessageTypeDelete, message.MessageTypeSchemaChange), - }, - 
MessageHandler: handler, - }) - p.input = handler.Chan() - return nil - } - start := time.Now() p.input, err = p.dispatcher.Register(ctx, &msgdispatcher.StreamConfig{ VChannel: p.vChannel, diff --git a/pkg/mq/msgstream/msgstream_util.go b/pkg/mq/msgstream/msgstream_util.go index 215444554c..1c9f2553ab 100644 --- a/pkg/mq/msgstream/msgstream_util.go +++ b/pkg/mq/msgstream/msgstream_util.go @@ -173,3 +173,20 @@ func BuildConsumeMsgPack(pack *MsgPack) *ConsumeMsgPack { EndPositions: pack.EndPositions, } } + +// MustBuildMsgPackFromConsumeMsgPack is a helper function to build MsgPack from ConsumeMsgPack. +func MustBuildMsgPackFromConsumeMsgPack(pack *ConsumeMsgPack, unmarshaler UnmarshalDispatcher) *MsgPack { + return &MsgPack{ + BeginTs: pack.BeginTs, + EndTs: pack.EndTs, + Msgs: lo.Map(pack.Msgs, func(msg ConsumeMsg, _ int) TsMsg { + tsMsg, err := msg.Unmarshal(unmarshaler) + if err != nil { + panic("failed to unmarshal msg: " + err.Error()) + } + return tsMsg + }), + StartPositions: pack.StartPositions, + EndPositions: pack.EndPositions, + } +} diff --git a/pkg/streaming/util/message/adaptor/handler.go b/pkg/streaming/util/message/adaptor/handler.go index c3a2b6d4e7..ce7303b65e 100644 --- a/pkg/streaming/util/message/adaptor/handler.go +++ b/pkg/streaming/util/message/adaptor/handler.go @@ -36,18 +36,18 @@ func (d ChanMessageHandler) Close() { // NewMsgPackAdaptorHandler create a new message pack adaptor handler. func NewMsgPackAdaptorHandler() *MsgPackAdaptorHandler { return &MsgPackAdaptorHandler{ - channel: make(chan *msgstream.MsgPack), + channel: make(chan *msgstream.ConsumeMsgPack), base: NewBaseMsgPackAdaptorHandler(), } } type MsgPackAdaptorHandler struct { - channel chan *msgstream.MsgPack + channel chan *msgstream.ConsumeMsgPack base *BaseMsgPackAdaptorHandler } // Chan is the channel for message. -func (m *MsgPackAdaptorHandler) Chan() <-chan *msgstream.MsgPack { +func (m *MsgPackAdaptorHandler) Chan() <-chan *msgstream.ConsumeMsgPack { return m.channel } @@ -61,7 +61,7 @@ func (m *MsgPackAdaptorHandler) Handle(param message.HandleParam) message.Handle } for { - var sendCh chan<- *msgstream.MsgPack + var sendCh chan<- *msgstream.ConsumeMsgPack if m.base.PendingMsgPack.Len() != 0 { sendCh = m.channel } @@ -100,15 +100,15 @@ func NewBaseMsgPackAdaptorHandler() *BaseMsgPackAdaptorHandler { return &BaseMsgPackAdaptorHandler{ Logger: log.With(), Pendings: make([]message.ImmutableMessage, 0), - PendingMsgPack: typeutil.NewMultipartQueue[*msgstream.MsgPack](), + PendingMsgPack: typeutil.NewMultipartQueue[*msgstream.ConsumeMsgPack](), } } // BaseMsgPackAdaptorHandler is the handler for message pack. type BaseMsgPackAdaptorHandler struct { Logger *log.MLogger - Pendings []message.ImmutableMessage // pendings hold the vOld message which has same time tick. - PendingMsgPack *typeutil.MultipartQueue[*msgstream.MsgPack] // pendingMsgPack hold unsent msgPack. + Pendings []message.ImmutableMessage // pendings hold the vOld message which has same time tick. + PendingMsgPack *typeutil.MultipartQueue[*msgstream.ConsumeMsgPack] // pendingMsgPack hold unsent msgPack. } // GenerateMsgPack generate msgPack from message. 
@@ -142,6 +142,6 @@ func (m *BaseMsgPackAdaptorHandler) addMsgPackIntoPending(msgs ...message.Immuta m.Logger.Warn("failed to convert message to msgpack", zap.Error(err)) } if newPack != nil { - m.PendingMsgPack.AddOne(newPack) + m.PendingMsgPack.AddOne(msgstream.BuildConsumeMsgPack(newPack)) } } diff --git a/tests/go_client/testcases/collection_test.go b/tests/go_client/testcases/collection_test.go index cd77f961ef..99ea342803 100644 --- a/tests/go_client/testcases/collection_test.go +++ b/tests/go_client/testcases/collection_test.go @@ -61,13 +61,13 @@ func TestCreateCollectionFast(t *testing.T) { prepare, _ := hp.CollPrepare.InsertData(ctx, t, mc, hp.NewInsertParams(coll.Schema), hp.TNewDataOption()) prepare.FlushData(ctx, t, mc, collName) - countRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("").WithOutputFields(common.QueryCountFieldName)) + countRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("").WithOutputFields(common.QueryCountFieldName).WithConsistencyLevel(entity.ClStrong)) common.CheckErr(t, err, true) count, _ := countRes.Fields[0].GetAsInt64(0) require.EqualValues(t, common.DefaultNb, count) vectors := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector) - resSearch, err := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, vectors)) + resSearch, err := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, vectors).WithConsistencyLevel(entity.ClStrong)) common.CheckErr(t, err, true) common.CheckSearchResult(t, resSearch, common.DefaultNq, common.DefaultLimit) } diff --git a/tests/python_client/testcases/test_partition_key_isolation.py b/tests/python_client/testcases/test_partition_key_isolation.py index f96f6a6e26..ef1cb09eb6 100644 --- a/tests/python_client/testcases/test_partition_key_isolation.py +++ b/tests/python_client/testcases/test_partition_key_isolation.py @@ -116,7 +116,8 @@ class TestPartitionKeyIsolation(TestcaseBase): expr=expr, param={"metric_type": "L2", "params": {"nprobe": 16}}, limit=10000, - output_fields=["scalar_3", "scalar_6", "scalar_12"] + output_fields=["scalar_3", "scalar_6", "scalar_12"], + consistency_level="Strong" ) log.info(f"search res {res}") true_res = all_df.query(expr) @@ -218,7 +219,8 @@ class TestPartitionKeyIsolation(TestcaseBase): expr=expr, param={"metric_type": "L2", "params": {"nprobe": 16}}, limit=10, - output_fields=["scalar_6"] + output_fields=["scalar_6"], + consistency_level="Strong" ) log.info(f"search with {expr} get res {res}") false_result.append(expr) @@ -348,7 +350,8 @@ class TestPartitionKeyIsolation(TestcaseBase): expr="scalar_6 == '1' and scalar_3 == '1'", param={"metric_type": "L2", "params": {"nprobe": 16}}, limit=10, - output_fields=["scalar_6", "scalar_3"] + output_fields=["scalar_6", "scalar_3"], + consistency_level="Strong" ) log.info(f"search res {res}") assert len(res[0]) > 0 diff --git a/tests/python_client/testcases/test_phrase_match.py b/tests/python_client/testcases/test_phrase_match.py index 7e53224cec..d4ab9eb389 100644 --- a/tests/python_client/testcases/test_phrase_match.py +++ b/tests/python_client/testcases/test_phrase_match.py @@ -8,6 +8,7 @@ from common.common_type import CheckTasks from utils.util_log import test_log as log from common import common_func as cf from base.client_base import TestcaseBase +import time prefix = "phrase_match" @@ -72,6 +73,7 @@ class TestQueryPhraseMatch(TestcaseBase): collection_w = self.init_collection_wrap( name=cf.gen_unique_str(prefix), 
schema=init_collection_schema(dim, tokenizer, enable_partition_key), + consistency_level="Strong", ) # Generate test data @@ -174,6 +176,7 @@ class TestQueryPhraseMatch(TestcaseBase): collection_w = self.init_collection_wrap( name=cf.gen_unique_str(prefix), schema=init_collection_schema(dim, tokenizer, enable_partition_key), + consistency_level="Strong", ) # Generate test data @@ -251,6 +254,7 @@ class TestQueryPhraseMatch(TestcaseBase): collection_w = self.init_collection_wrap( name=cf.gen_unique_str(prefix), schema=init_collection_schema(dim, tokenizer, enable_partition_key), + consistency_level="Strong", ) # Generate test data @@ -327,7 +331,7 @@ class TestQueryPhraseMatch(TestcaseBase): dim = 128 collection_name = f"{prefix}_patterns" schema = init_collection_schema(dim, "standard", False) - collection = self.init_collection_wrap(name=collection_name, schema=schema) + collection = self.init_collection_wrap(name=collection_name, schema=schema, consistency_level="Strong") # Generate data with various patterns generator = PhraseMatchTestGenerator(language="en") @@ -362,10 +366,11 @@ class TestQueryPhraseMatch(TestcaseBase): }, ) collection.load() + time.sleep(1) for pattern, slop in test_patterns: results, _ = collection.query( - expr=f'phrase_match(text, "{pattern}", {slop})', output_fields=["text"] + expr=f'phrase_match(text, "{pattern}", {slop})', output_fields=["text"], ) log.info( f"Pattern '{pattern}' with slop {slop} found {len(results)} matches" @@ -391,7 +396,7 @@ class TestQueryPhraseMatchNegative(TestcaseBase): dim = 128 collection_name = f"{prefix}_invalid_slop" schema = init_collection_schema(dim, "standard", False) - collection = self.init_collection_wrap(name=collection_name, schema=schema) + collection = self.init_collection_wrap(name=collection_name, schema=schema, consistency_level="Strong") # Insert some test data generator = PhraseMatchTestGenerator(language="en")
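
For reference, the pieces above fit together as in the following sketch. It is illustrative only and not part of the patch: the package name, the function names, and the `nodeID`/`position` arguments are assumptions, while the called APIs are the ones that appear in the hunks above.

package example

import (
	"context"

	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
	"github.com/milvus-io/milvus/internal/distributed/streaming"
	"github.com/milvus-io/milvus/pkg/v2/mq/msgdispatcher"
	"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message/adaptor"
	"github.com/milvus-io/milvus/pkg/v2/streaming/util/options"
	"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)

// newDelegatorDispatcherClient wires the streaming-backed factory into the
// msgdispatcher, as the server.go hunk does when streaming is enabled.
func newDelegatorDispatcherClient(nodeID int64) msgdispatcher.Client {
	return msgdispatcher.NewClient(streaming.NewDelegatorMsgstreamFactory(), typeutil.QueryNodeRole, nodeID)
}

// readDeleteStream opens a WAL scanner from a checkpoint position and returns
// the adapted msgstream channel plus a closer. Either VChannel or PChannel may
// be set in ReadOption; here the pchannel is derived explicitly, as the
// adaptor's Seek does.
func readDeleteStream(ctx context.Context, position *msgpb.MsgPosition) (<-chan *msgstream.ConsumeMsgPack, func()) {
	handler := adaptor.NewMsgPackAdaptorHandler()
	scanner := streaming.WAL().Read(ctx, streaming.ReadOption{
		PChannel: funcutil.ToPhysicalChannel(position.GetChannelName()),
		DeliverPolicy: options.DeliverPolicyStartFrom(
			adaptor.MustGetMessageIDFromMQWrapperIDBytes(streaming.WAL().WALName(), position.GetMsgID()),
		),
		DeliverFilters: []options.DeliverFilter{
			// only deliver messages with timestamp >= the checkpoint timestamp
			options.DeliverFilterTimeTickGTE(position.GetTimestamp()),
			// only delete messages
			options.DeliverFilterMessageType(message.MessageTypeDelete),
		},
		MessageHandler: handler,
	})
	return handler.Chan(), scanner.Close
}

With this wiring, the delegator no longer branches on whether the streaming service is enabled: the dispatcher client is built from this factory when streaming is on and from the plain msgstream factory otherwise, and the delegator always consumes through the dispatcher.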