diff --git a/Makefile b/Makefile index 2af5c2fc74..721c6bdea7 100644 --- a/Makefile +++ b/Makefile @@ -537,7 +537,7 @@ generate-mockery-chunk-manager: getdeps generate-mockery-pkg: $(MAKE) -C pkg generate-mockery -generate-mockery-internal: +generate-mockery-internal: getdeps $(INSTALL_PATH)/mockery --config $(PWD)/internal/.mockery.yaml generate-mockery: generate-mockery-types generate-mockery-kv generate-mockery-rootcoord generate-mockery-proxy generate-mockery-querycoord generate-mockery-querynode generate-mockery-datacoord generate-mockery-pkg generate-mockery-internal diff --git a/cmd/roles/roles.go b/cmd/roles/roles.go index 51e9fd37d4..dcbe628300 100644 --- a/cmd/roles/roles.go +++ b/cmd/roles/roles.go @@ -35,12 +35,14 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/cmd/components" + "github.com/milvus-io/milvus/internal/distributed/streaming" "github.com/milvus-io/milvus/internal/http" "github.com/milvus-io/milvus/internal/http/healthz" "github.com/milvus-io/milvus/internal/util/dependency" kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv" "github.com/milvus-io/milvus/internal/util/initcore" internalmetrics "github.com/milvus-io/milvus/internal/util/metrics" + "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/config" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" @@ -377,6 +379,12 @@ func (mr *MilvusRoles) Run() { paramtable.SetRole(mr.ServerType) } + // Initialize streaming service if enabled. + if streamingutil.IsStreamingServiceEnabled() { + streaming.Init() + defer streaming.Release() + } + expr.Init() expr.Register("param", paramtable.Get()) mr.setupLogger() diff --git a/cmd/tools/config/generate.go b/cmd/tools/config/generate.go index fe8b1e39a2..cdc5cd15bf 100644 --- a/cmd/tools/config/generate.go +++ b/cmd/tools/config/generate.go @@ -328,6 +328,11 @@ func WriteYaml(w io.Writer) { #milvus will automatically initialize half of the available GPU memory, #maxMemSize will the whole available GPU memory.`, }, + { + name: "streamingNode", + header: ` +# Any configuration related to the streaming node server.`, + }, } marshller := YamlMarshaller{w, groups, result} marshller.writeYamlRecursive(lo.Filter(result, func(d DocContent, _ int) bool { diff --git a/configs/milvus.yaml b/configs/milvus.yaml index c8c74abecc..f931e3cbde 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -1034,3 +1034,13 @@ trace: gpu: initMemSize: 2048 # Gpu Memory Pool init size maxMemSize: 4096 # Gpu Memory Pool Max size + +# Any configuration related to the streaming node server. +streamingNode: + ip: # TCP/IP address of streamingNode. 
If not specified, use the first unicastable address + port: 22222 # TCP port of streamingNode + grpc: + serverMaxSendSize: 268435456 # The maximum size of each RPC request that the streamingNode can send, unit: byte + serverMaxRecvSize: 268435456 # The maximum size of each RPC request that the streamingNode can receive, unit: byte + clientMaxSendSize: 268435456 # The maximum size of each RPC request that the clients on streamingNode can send, unit: byte + clientMaxRecvSize: 268435456 # The maximum size of each RPC request that the clients on streamingNode can receive, unit: byte diff --git a/docker-compose.yml b/docker-compose.yml index bb99f4c0ce..3e20108b71 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -97,14 +97,23 @@ services: - ETCD_QUOTA_BACKEND_BYTES=4294967296 - ETCD_SNAPSHOT_COUNT=50000 healthcheck: - test: ['CMD', '/opt/bitnami/scripts/etcd/healthcheck.sh'] + test: [ 'CMD', '/opt/bitnami/scripts/etcd/healthcheck.sh' ] interval: 30s timeout: 20s retries: 3 pulsar: image: apachepulsar/pulsar:2.8.2 - command: bin/pulsar standalone --no-functions-worker --no-stream-storage + command: | + /bin/bash -c \ + "bin/apply-config-from-env.py conf/standalone.conf && \ + exec bin/pulsar standalone --no-functions-worker --no-stream-storage" + environment: + # 10MB + - PULSAR_PREFIX_maxMessageSize=10485760 + # this is 104857600 + 10240 (padding) + - nettyMaxFrameSizeBytes=104867840 + - PULSAR_GC=-XX:+UseG1GC minio: image: minio/minio:RELEASE.2023-03-20T20-16-18Z @@ -113,7 +122,7 @@ services: MINIO_SECRET_KEY: minioadmin command: minio server /minio_data healthcheck: - test: ['CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live'] + test: [ 'CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live' ] interval: 30s timeout: 20s retries: 3 diff --git a/go.mod b/go.mod index 2992847262..2cc19b1386 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 github.com/klauspost/compress v1.17.7 github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d - github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240815123953-6dab6fcd6454 + github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240820032106-b34be93a2271 github.com/minio/minio-go/v7 v7.0.61 github.com/pingcap/log v1.1.1-0.20221015072633-39906604fb81 github.com/prometheus/client_golang v1.14.0 diff --git a/go.sum b/go.sum index b5dc7c85b0..b34026898d 100644 --- a/go.sum +++ b/go.sum @@ -598,8 +598,8 @@ github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119 h1:9VXijWu github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119/go.mod h1:DvXTE/K/RtHehxU8/GtDs4vFtfw64jJ3PaCnFri8CRg= github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b h1:TfeY0NxYxZzUfIfYe5qYDBzt4ZYRqzUjTR6CvUzjat8= github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b/go.mod h1:iwW+9cWfIzzDseEBCCeDSN5SD16Tidvy8cwQ7ZY8Qj4= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240815123953-6dab6fcd6454 h1:JmZCYjMPpiE4ksZw0AUxXWkDY7wwA4fhS+SO1N211Vw= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240815123953-6dab6fcd6454/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs= +github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240820032106-b34be93a2271 h1:YUWBgtRHmvkxMPTfOrY3FIq0K5XHw02Z18z7cyaMH04= +github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240820032106-b34be93a2271/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs= github.com/milvus-io/pulsar-client-go v0.6.10 h1:eqpJjU+/QX0iIhEo3nhOqMNXL+TyInAs1IAHZCrCM/A= 
github.com/milvus-io/pulsar-client-go v0.6.10/go.mod h1:lQqCkgwDF8YFYjKA+zOheTk1tev2B+bKj5j7+nm8M1w= github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs= diff --git a/internal/datacoord/channel_manager.go b/internal/datacoord/channel_manager.go index b4ebdb118f..76774ead73 100644 --- a/internal/datacoord/channel_manager.go +++ b/internal/datacoord/channel_manager.go @@ -28,6 +28,7 @@ import ( "github.com/milvus-io/milvus/internal/datacoord/allocator" "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/kv" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/conc" @@ -160,7 +161,7 @@ func (m *ChannelManagerImpl) Startup(ctx context.Context, legacyNodes, allNodes m.finishRemoveChannel(info.NodeID, lo.Values(info.Channels)...) } - if m.balanceCheckLoop != nil { + if m.balanceCheckLoop != nil && !streamingutil.IsStreamingServiceEnabled() { log.Info("starting channel balance loop") m.wg.Add(1) go func() { @@ -329,6 +330,12 @@ func (m *ChannelManagerImpl) Balance() { } func (m *ChannelManagerImpl) Match(nodeID UniqueID, channel string) bool { + if streamingutil.IsStreamingServiceEnabled() { + // Skip the channel matching check since the + // channel manager no longer manages channels in streaming mode. + return true + } + m.mu.RLock() defer m.mu.RUnlock() diff --git a/internal/datacoord/server.go b/internal/datacoord/server.go index 137990b74b..25290eec34 100644 --- a/internal/datacoord/server.go +++ b/internal/datacoord/server.go @@ -724,7 +724,10 @@ func (s *Server) startServerLoop() { go s.importScheduler.Start() go s.importChecker.Start() s.garbageCollector.start() - s.syncSegmentsScheduler.Start() + + if !streamingutil.IsStreamingServiceEnabled() { + s.syncSegmentsScheduler.Start() + } } func (s *Server) updateSegmentStatistics(stats []*commonpb.SegmentStats) { diff --git a/internal/datacoord/services.go b/internal/datacoord/services.go index a63f243906..6eb60e5600 100644 --- a/internal/datacoord/services.go +++ b/internal/datacoord/services.go @@ -38,6 +38,7 @@ import ( "github.com/milvus-io/milvus/internal/util/componentutil" "github.com/milvus-io/milvus/internal/util/importutilv2" "github.com/milvus-io/milvus/internal/util/segmentutil" + "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" @@ -111,14 +112,16 @@ func (s *Server) Flush(ctx context.Context, req *datapb.FlushRequest) (*datapb.F } timeOfSeal, _ := tsoutil.ParseTS(ts) - sealedSegmentIDs, err := s.segmentManager.SealAllSegments(ctx, req.GetCollectionID(), req.GetSegmentIDs()) - if err != nil { - return &datapb.FlushResponse{ - Status: merr.Status(errors.Wrapf(err, "failed to flush collection %d", - req.GetCollectionID())), - }, nil + sealedSegmentIDs := make([]int64, 0) + if !streamingutil.IsStreamingServiceEnabled() { + var err error + if sealedSegmentIDs, err = s.segmentManager.SealAllSegments(ctx, req.GetCollectionID(), req.GetSegmentIDs()); err != nil { + return &datapb.FlushResponse{ + Status: merr.Status(errors.Wrapf(err, "failed to flush collection %d", + req.GetCollectionID())), + }, nil + } } - sealedSegmentsIDDict := make(map[UniqueID]bool) for _, sealedSegmentID := range sealedSegmentIDs { sealedSegmentsIDDict[sealedSegmentID] = true @@ -135,33 +138,35 @@ func (s *Server) Flush(ctx context.Context, req 
*datapb.FlushRequest) (*datapb.F } } - var isUnimplemented bool - err = retry.Do(ctx, func() error { - nodeChannels := s.channelManager.GetNodeChannelsByCollectionID(req.GetCollectionID()) + if !streamingutil.IsStreamingServiceEnabled() { + var isUnimplemented bool + err = retry.Do(ctx, func() error { + nodeChannels := s.channelManager.GetNodeChannelsByCollectionID(req.GetCollectionID()) - for nodeID, channelNames := range nodeChannels { - err = s.cluster.FlushChannels(ctx, nodeID, ts, channelNames) - if err != nil && errors.Is(err, merr.ErrServiceUnimplemented) { - isUnimplemented = true - return nil - } - if err != nil { - return err + for nodeID, channelNames := range nodeChannels { + err = s.cluster.FlushChannels(ctx, nodeID, ts, channelNames) + if err != nil && errors.Is(err, merr.ErrServiceUnimplemented) { + isUnimplemented = true + return nil + } + if err != nil { + return err + } } + return nil + }, retry.Attempts(60)) // about 3min + if err != nil { + return &datapb.FlushResponse{ + Status: merr.Status(err), + }, nil } - return nil - }, retry.Attempts(60)) // about 3min - if err != nil { - return &datapb.FlushResponse{ - Status: merr.Status(err), - }, nil - } - if isUnimplemented { - // For compatible with rolling upgrade from version 2.2.x, - // fall back to the flush logic of version 2.2.x; - log.Warn("DataNode FlushChannels unimplemented", zap.Error(err)) - ts = 0 + if isUnimplemented { + // For compatible with rolling upgrade from version 2.2.x, + // fall back to the flush logic of version 2.2.x; + log.Warn("DataNode FlushChannels unimplemented", zap.Error(err)) + ts = 0 + } } log.Info("flush response with segments", @@ -255,6 +260,12 @@ func (s *Server) AllocSegment(ctx context.Context, req *datapb.AllocSegmentReque return &datapb.AllocSegmentResponse{Status: merr.Status(merr.ErrParameterInvalid)}, nil } + // refresh the meta of the collection. + _, err := s.handler.GetCollection(ctx, req.GetCollectionId()) + if err != nil { + return &datapb.AllocSegmentResponse{Status: merr.Status(err)}, nil + } + // Alloc new growing segment and return the segment info. 
segmentInfo, err := s.segmentManager.AllocNewGrowingSegment(ctx, req.GetCollectionId(), req.GetPartitionId(), req.GetSegmentId(), req.GetVchannel()) if err != nil { diff --git a/internal/datanode/data_node.go b/internal/datanode/data_node.go index 871d138438..ba9527ef2a 100644 --- a/internal/datanode/data_node.go +++ b/internal/datanode/data_node.go @@ -49,6 +49,7 @@ import ( "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/kv" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" @@ -308,21 +309,23 @@ func (node *DataNode) Start() error { return } - node.writeBufferManager.Start() + if !streamingutil.IsStreamingServiceEnabled() { + node.writeBufferManager.Start() + + node.timeTickSender = util2.NewTimeTickSender(node.broker, node.session.ServerID, + retry.Attempts(20), retry.Sleep(time.Millisecond*100)) + node.timeTickSender.Start() + + node.channelManager = channel.NewChannelManager(getPipelineParams(node), node.flowgraphManager) + node.channelManager.Start() + + go node.channelCheckpointUpdater.Start() + } go node.compactionExecutor.Start(node.ctx) go node.importScheduler.Start() - node.timeTickSender = util2.NewTimeTickSender(node.broker, node.session.ServerID, - retry.Attempts(20), retry.Sleep(time.Millisecond*100)) - node.timeTickSender.Start() - - go node.channelCheckpointUpdater.Start() - - node.channelManager = channel.NewChannelManager(getPipelineParams(node), node.flowgraphManager) - node.channelManager.Start() - node.UpdateStateCode(commonpb.StateCode_Healthy) }) return startErr diff --git a/internal/distributed/streaming/internal/errs/error.go b/internal/distributed/streaming/internal/errs/error.go index 5001ac442b..ff1b6c4ce6 100644 --- a/internal/distributed/streaming/internal/errs/error.go +++ b/internal/distributed/streaming/internal/errs/error.go @@ -6,7 +6,7 @@ import ( // All error in streamingservice package should be marked by streamingservice/errs package. var ( - ErrClosed = errors.New("closed") - ErrCanceled = errors.New("canceled") - ErrTxnUnavailable = errors.New("transaction unavailable") + ErrClosed = errors.New("closed") + ErrCanceledOrDeadlineExceed = errors.New("canceled or deadline exceed") + ErrUnrecoverable = errors.New("unrecoverable") ) diff --git a/internal/distributed/streaming/internal/producer/producer.go b/internal/distributed/streaming/internal/producer/producer.go index 6c080e5fd3..8372449446 100644 --- a/internal/distributed/streaming/internal/producer/producer.go +++ b/internal/distributed/streaming/internal/producer/producer.go @@ -95,13 +95,14 @@ func (p *ResumableProducer) Produce(ctx context.Context, msg message.MutableMess } // It's ok to stop retry if the error is canceled or deadline exceed. if status.IsCanceled(err) { - return nil, errors.Mark(err, errs.ErrCanceled) + return nil, errors.Mark(err, errs.ErrCanceledOrDeadlineExceed) } if sErr := status.AsStreamingError(err); sErr != nil { - // if the error is txn unavailable, it cannot be retried forever. + // if the error is txn unavailable or unrecoverable error, + // it cannot be retried forever. // we should mark it and return. 
- if sErr.IsTxnUnavilable() { - return nil, errors.Mark(err, errs.ErrTxnUnavailable) + if sErr.IsUnrecoverable() { + return nil, errors.Mark(err, errs.ErrUnrecoverable) } } } diff --git a/internal/distributed/streaming/internal/producer/producer_resuming.go b/internal/distributed/streaming/internal/producer/producer_resuming.go index 5560901ba1..c0469225f1 100644 --- a/internal/distributed/streaming/internal/producer/producer_resuming.go +++ b/internal/distributed/streaming/internal/producer/producer_resuming.go @@ -30,7 +30,7 @@ func (p *producerWithResumingError) GetProducerAfterAvailable(ctx context.Contex p.cond.L.Lock() for p.err == nil && (p.producer == nil || !p.producer.IsAvailable()) { if err := p.cond.Wait(ctx); err != nil { - return nil, errors.Mark(err, errs.ErrCanceled) + return nil, errors.Mark(err, errs.ErrCanceledOrDeadlineExceed) } } err := p.err diff --git a/internal/distributed/streaming/internal/producer/producer_test.go b/internal/distributed/streaming/internal/producer/producer_test.go index f85134a7fd..2230f0aedf 100644 --- a/internal/distributed/streaming/internal/producer/producer_test.go +++ b/internal/distributed/streaming/internal/producer/producer_test.go @@ -88,7 +88,7 @@ func TestResumableProducer(t *testing.T) { id, err = rp.Produce(ctx, msg) assert.Nil(t, id) assert.Error(t, err) - assert.True(t, errors.Is(err, errs.ErrCanceled)) + assert.True(t, errors.Is(err, errs.ErrCanceledOrDeadlineExceed)) // Test the underlying handler close. close(ch2) diff --git a/internal/distributed/streaming/streaming.go b/internal/distributed/streaming/streaming.go index 3bd722a357..aa39c4e1de 100644 --- a/internal/distributed/streaming/streaming.go +++ b/internal/distributed/streaming/streaming.go @@ -80,14 +80,24 @@ type WALAccesser interface { // Once the txn is returned, the Commit or Rollback operation must be called once, otherwise resource leak on wal. Txn(ctx context.Context, opts TxnOption) (Txn, error) - // Append writes a records to the log. - Append(ctx context.Context, msgs message.MutableMessage, opts ...AppendOption) (*types.AppendResult, error) + // RawAppend writes a records to the log. + RawAppend(ctx context.Context, msgs message.MutableMessage, opts ...AppendOption) (*types.AppendResult, error) // Read returns a scanner for reading records from the wal. Read(ctx context.Context, opts ReadOption) Scanner - // Utility returns the utility for writing records to the log. - Utility() Utility + // AppendMessages appends messages to the wal. + // It it a helper utility function to append messages to the wal. + // If the messages is belong to one vchannel, it will be sent as a transaction. + // Otherwise, it will be sent as individual messages. + // !!! This function do not promise the atomicity and deliver order of the messages appending. + // TODO: Remove after we support cross-wal txn. + AppendMessages(ctx context.Context, msgs ...message.MutableMessage) AppendResponses + + // AppendMessagesWithOption appends messages to the wal with the given option. + // Same with AppendMessages, but with the given option. + // TODO: Remove after we support cross-wal txn. + AppendMessagesWithOption(ctx context.Context, opts AppendOption, msgs ...message.MutableMessage) AppendResponses } // Txn is the interface for writing transaction into the wal. @@ -105,18 +115,3 @@ type Txn interface { // It is preserved for future cross-wal txn. Rollback(ctx context.Context) error } - -type Utility interface { - // AppendMessages appends messages to the wal. 
- // It it a helper utility function to append messages to the wal. - // If the messages is belong to one vchannel, it will be sent as a transaction. - // Otherwise, it will be sent as individual messages. - // !!! This function do not promise the atomicity and deliver order of the messages appending. - // TODO: Remove after we support cross-wal txn. - AppendMessages(ctx context.Context, msgs ...message.MutableMessage) AppendResponses - - // AppendMessagesWithOption appends messages to the wal with the given option. - // Same with AppendMessages, but with the given option. - // TODO: Remove after we support cross-wal txn. - AppendMessagesWithOption(ctx context.Context, opts AppendOption, msgs ...message.MutableMessage) AppendResponses -} diff --git a/internal/distributed/streaming/streaming_test.go b/internal/distributed/streaming/streaming_test.go index e0db8bb3eb..c24f652616 100644 --- a/internal/distributed/streaming/streaming_test.go +++ b/internal/distributed/streaming/streaming_test.go @@ -35,7 +35,7 @@ func TestStreamingProduce(t *testing.T) { }). WithVChannel(vChannel). BuildMutable() - resp, err := streaming.WAL().Append(context.Background(), msg) + resp, err := streaming.WAL().RawAppend(context.Background(), msg) fmt.Printf("%+v\t%+v\n", resp, err) for i := 0; i < 500; i++ { @@ -49,7 +49,7 @@ func TestStreamingProduce(t *testing.T) { }). WithVChannel(vChannel). BuildMutable() - resp, err := streaming.WAL().Append(context.Background(), msg) + resp, err := streaming.WAL().RawAppend(context.Background(), msg) fmt.Printf("%+v\t%+v\n", resp, err) } @@ -92,7 +92,7 @@ func TestStreamingProduce(t *testing.T) { }). WithVChannel(vChannel). BuildMutable() - resp, err = streaming.WAL().Append(context.Background(), msg) + resp, err = streaming.WAL().RawAppend(context.Background(), msg) fmt.Printf("%+v\t%+v\n", resp, err) } diff --git a/internal/distributed/streaming/util.go b/internal/distributed/streaming/util.go index 8701ef9462..da1e0156a5 100644 --- a/internal/distributed/streaming/util.go +++ b/internal/distributed/streaming/util.go @@ -7,21 +7,14 @@ import ( "github.com/milvus-io/milvus/pkg/streaming/util/message" "github.com/milvus-io/milvus/pkg/streaming/util/types" - "github.com/milvus-io/milvus/pkg/util/conc" ) -type utility struct { - appendExecutionPool *conc.Pool[struct{}] - dispatchExecutionPool *conc.Pool[struct{}] - *walAccesserImpl -} - // AppendMessagesToWAL appends messages to the wal. // It it a helper utility function to append messages to the wal. // If the messages is belong to one vchannel, it will be sent as a transaction. // Otherwise, it will be sent as individual messages. // !!! This function do not promise the atomicity and deliver order of the messages appending. -func (u *utility) AppendMessages(ctx context.Context, msgs ...message.MutableMessage) AppendResponses { +func (u *walAccesserImpl) AppendMessages(ctx context.Context, msgs ...message.MutableMessage) AppendResponses { assertNoSystemMessage(msgs...) // dispatch the messages into different vchannel. @@ -58,7 +51,7 @@ func (u *utility) AppendMessages(ctx context.Context, msgs ...message.MutableMes } // AppendMessagesWithOption appends messages to the wal with the given option. 
-func (u *utility) AppendMessagesWithOption(ctx context.Context, opts AppendOption, msgs ...message.MutableMessage) AppendResponses { +func (u *walAccesserImpl) AppendMessagesWithOption(ctx context.Context, opts AppendOption, msgs ...message.MutableMessage) AppendResponses { for _, msg := range msgs { applyOpt(msg, opts) } @@ -66,7 +59,7 @@ func (u *utility) AppendMessagesWithOption(ctx context.Context, opts AppendOptio } // dispatchMessages dispatches the messages into different vchannel. -func (u *utility) dispatchMessages(msgs ...message.MutableMessage) (map[string][]message.MutableMessage, map[string][]int) { +func (u *walAccesserImpl) dispatchMessages(msgs ...message.MutableMessage) (map[string][]message.MutableMessage, map[string][]int) { dispatchedMessages := make(map[string][]message.MutableMessage, 0) indexes := make(map[string][]int, 0) for idx, msg := range msgs { @@ -82,7 +75,7 @@ func (u *utility) dispatchMessages(msgs ...message.MutableMessage) (map[string][ } // appendToVChannel appends the messages to the specified vchannel. -func (u *utility) appendToVChannel(ctx context.Context, vchannel string, msgs ...message.MutableMessage) AppendResponses { +func (u *walAccesserImpl) appendToVChannel(ctx context.Context, vchannel string, msgs ...message.MutableMessage) AppendResponses { if len(msgs) == 0 { return newAppendResponseN(0) } @@ -169,6 +162,16 @@ type AppendResponses struct { Responses []AppendResponse } +func (a AppendResponses) MaxTimeTick() uint64 { + var maxTimeTick uint64 + for _, r := range a.Responses { + if r.AppendResult != nil && r.AppendResult.TimeTick > maxTimeTick { + maxTimeTick = r.AppendResult.TimeTick + } + } + return maxTimeTick +} + // UnwrapFirstError returns the first error in the responses. func (a AppendResponses) UnwrapFirstError() error { for _, r := range a.Responses { diff --git a/internal/distributed/streaming/wal.go b/internal/distributed/streaming/wal.go index e5450226b0..2363e0397e 100644 --- a/internal/distributed/streaming/wal.go +++ b/internal/distributed/streaming/wal.go @@ -32,11 +32,10 @@ func newWALAccesser(c *clientv3.Client) *walAccesserImpl { handlerClient: handlerClient, producerMutex: sync.Mutex{}, producers: make(map[string]*producer.ResumableProducer), - utility: &utility{ - // TODO: optimize the pool size, use the streaming api but not goroutines. - appendExecutionPool: conc.NewPool[struct{}](10), - dispatchExecutionPool: conc.NewPool[struct{}](10), - }, + + // TODO: optimize the pool size, use the streaming api but not goroutines. + appendExecutionPool: conc.NewPool[struct{}](10), + dispatchExecutionPool: conc.NewPool[struct{}](10), } } @@ -48,13 +47,14 @@ type walAccesserImpl struct { streamingCoordAssignmentClient client.Client handlerClient handler.HandlerClient - producerMutex sync.Mutex - producers map[string]*producer.ResumableProducer - utility *utility + producerMutex sync.Mutex + producers map[string]*producer.ResumableProducer + appendExecutionPool *conc.Pool[struct{}] + dispatchExecutionPool *conc.Pool[struct{}] } -// Append writes a record to the log. -func (w *walAccesserImpl) Append(ctx context.Context, msg message.MutableMessage, opts ...AppendOption) (*types.AppendResult, error) { +// RawAppend writes a record to the log. 
+func (w *walAccesserImpl) RawAppend(ctx context.Context, msg message.MutableMessage, opts ...AppendOption) (*types.AppendResult, error) { assertNoSystemMessage(msg) if err := w.lifetime.Add(lifetime.IsWorking); err != nil { return nil, status.NewOnShutdownError("wal accesser closed, %s", err.Error()) @@ -125,15 +125,6 @@ func (w *walAccesserImpl) Txn(ctx context.Context, opts TxnOption) (Txn, error) }, nil } -// Utility returns the utility of the wal accesser. -func (w *walAccesserImpl) Utility() Utility { - return &utility{ - appendExecutionPool: w.utility.appendExecutionPool, - dispatchExecutionPool: w.utility.dispatchExecutionPool, - walAccesserImpl: w, - } -} - // Close closes all the wal accesser. func (w *walAccesserImpl) Close() { w.lifetime.SetState(lifetime.Stopped) diff --git a/internal/distributed/streaming/wal_test.go b/internal/distributed/streaming/wal_test.go index 258f30ec00..2ed2ed37d9 100644 --- a/internal/distributed/streaming/wal_test.go +++ b/internal/distributed/streaming/wal_test.go @@ -39,10 +39,8 @@ func TestWAL(t *testing.T) { handlerClient: handler, producerMutex: sync.Mutex{}, producers: make(map[string]*producer.ResumableProducer), - utility: &utility{ - appendExecutionPool: conc.NewPool[struct{}](10), - dispatchExecutionPool: conc.NewPool[struct{}](10), - }, + appendExecutionPool: conc.NewPool[struct{}](10), + dispatchExecutionPool: conc.NewPool[struct{}](10), } defer w.Close() @@ -70,7 +68,7 @@ func TestWAL(t *testing.T) { p.EXPECT().Close().Return() handler.EXPECT().CreateProducer(mock.Anything, mock.Anything).Return(p, nil) - result, err := w.Append(ctx, newInsertMessage(vChannel1)) + result, err := w.RawAppend(ctx, newInsertMessage(vChannel1)) assert.NoError(t, err) assert.NotNil(t, result) @@ -107,7 +105,7 @@ func TestWAL(t *testing.T) { err = txn.Rollback(ctx) assert.NoError(t, err) - resp := w.Utility().AppendMessages(ctx, + resp := w.AppendMessages(ctx, newInsertMessage(vChannel1), newInsertMessage(vChannel2), newInsertMessage(vChannel2), diff --git a/internal/distributed/streamingnode/service.go b/internal/distributed/streamingnode/service.go index 8a909b2324..eec82a7cff 100644 --- a/internal/distributed/streamingnode/service.go +++ b/internal/distributed/streamingnode/service.go @@ -76,21 +76,22 @@ type Server struct { grpcServer *grpc.Server lis net.Listener + factory dependency.Factory + // component client etcdCli *clientv3.Client tikvCli *txnkv.Client rootCoord types.RootCoordClient dataCoord types.DataCoordClient chunkManager storage.ChunkManager - f dependency.Factory } // NewServer create a new StreamingNode server. func NewServer(f dependency.Factory) (*Server, error) { return &Server{ stopOnce: sync.Once{}, + factory: f, grpcServerChan: make(chan struct{}), - f: f, }, nil } @@ -180,6 +181,9 @@ func (s *Server) init(ctx context.Context) (err error) { if err := s.initMeta(); err != nil { return err } + if err := s.initChunkManager(ctx); err != nil { + return err + } if err := s.allocateAddress(); err != nil { return err } @@ -192,14 +196,12 @@ func (s *Server) init(ctx context.Context) (err error) { if err := s.initDataCoord(ctx); err != nil { return err } - if err := s.initChunkManager(ctx); err != nil { - return err - } s.initGRPCServer() // Create StreamingNode service. s.streamingnode = streamingnodeserver.NewServerBuilder(). WithETCD(s.etcdCli). + WithChunkManager(s.chunkManager). WithGRPCServer(s.grpcServer). WithRootCoordClient(s.rootCoord). WithDataCoordClient(s.dataCoord). 
@@ -305,8 +307,8 @@ func (s *Server) initDataCoord(ctx context.Context) (err error) { func (s *Server) initChunkManager(ctx context.Context) (err error) { log.Info("StreamingNode init chunk manager...") - s.f.Init(paramtable.Get()) - manager, err := s.f.NewPersistentStorageChunkManager(ctx) + s.factory.Init(paramtable.Get()) + manager, err := s.factory.NewPersistentStorageChunkManager(ctx) if err != nil { return errors.Wrap(err, "StreamingNode try to new chunk manager failed") } diff --git a/internal/flushcommon/pipeline/data_sync_service.go b/internal/flushcommon/pipeline/data_sync_service.go index 0db7379b9e..bdf5b3d5ff 100644 --- a/internal/flushcommon/pipeline/data_sync_service.go +++ b/internal/flushcommon/pipeline/data_sync_service.go @@ -33,6 +33,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/flowgraph" + "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgdispatcher" @@ -61,8 +62,8 @@ type DataSyncService struct { broker broker.Broker syncMgr syncmgr.SyncManager - timetickSender *util.TimeTickSender // reference to TimeTickSender - compactor compaction.Executor // reference to compaction executor + timetickSender util.StatsUpdater // reference to TimeTickSender + compactor compaction.Executor // reference to compaction executor dispClient msgdispatcher.Client chunkManager storage.ChunkManager @@ -159,7 +160,7 @@ func initMetaCache(initCtx context.Context, chunkManager storage.ChunkManager, i return nil, err } segmentPks.Insert(segment.GetID(), pkoracle.NewBloomFilterSet(stats...)) - if tickler != nil { + if !streamingutil.IsStreamingServiceEnabled() { tickler.Inc() } diff --git a/internal/flushcommon/pipeline/flow_graph_dd_node.go b/internal/flushcommon/pipeline/flow_graph_dd_node.go index 213c08aeb3..14ddf9f7eb 100644 --- a/internal/flushcommon/pipeline/flow_graph_dd_node.go +++ b/internal/flushcommon/pipeline/flow_graph_dd_node.go @@ -33,9 +33,11 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/streamingnode/server/flusher" "github.com/milvus-io/milvus/internal/util/flowgraph" + "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/streaming/util/message/adaptor" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -154,7 +156,9 @@ func (ddn *ddNode) Operate(in []Msg) []Msg { ddn.dropMode.Store(true) log.Info("Stop compaction for dropped channel", zap.String("channel", ddn.vChannelName)) - ddn.compactionExecutor.DiscardByDroppedChannel(ddn.vChannelName) + if !streamingutil.IsStreamingServiceEnabled() { + ddn.compactionExecutor.DiscardByDroppedChannel(ddn.vChannelName) + } fgMsg.dropCollection = true } @@ -232,10 +236,32 @@ func (ddn *ddNode) Operate(in []Msg) []Msg { WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.DeleteLabel). 
Add(float64(dmsg.GetNumRows())) fgMsg.DeleteMessages = append(fgMsg.DeleteMessages, dmsg) - - case commonpb.MsgType_Flush: - if ddn.flushMsgHandler != nil { - ddn.flushMsgHandler(ddn.vChannelName, nil) + case commonpb.MsgType_FlushSegment: + flushMsg := msg.(*adaptor.FlushMessageBody) + logger := log.With( + zap.String("vchannel", ddn.Name()), + zap.Int32("msgType", int32(msg.Type())), + zap.Uint64("timetick", flushMsg.FlushMessage.TimeTick()), + ) + logger.Info("receive flush message") + if err := ddn.flushMsgHandler.HandleFlush(ddn.vChannelName, flushMsg.FlushMessage); err != nil { + logger.Warn("handle flush message failed", zap.Error(err)) + } else { + logger.Info("handle flush message success") + } + case commonpb.MsgType_ManualFlush: + manualFlushMsg := msg.(*adaptor.ManualFlushMessageBody) + logger := log.With( + zap.String("vchannel", ddn.Name()), + zap.Int32("msgType", int32(msg.Type())), + zap.Uint64("timetick", manualFlushMsg.ManualFlushMessage.TimeTick()), + zap.Uint64("flushTs", manualFlushMsg.ManualFlushMessage.Header().FlushTs), + ) + logger.Info("receive manual flush message") + if err := ddn.flushMsgHandler.HandleManualFlush(ddn.vChannelName, manualFlushMsg.ManualFlushMessage); err != nil { + logger.Warn("handle manual flush message failed", zap.Error(err)) + } else { + logger.Info("handle manual flush message success") } } } diff --git a/internal/flushcommon/pipeline/flow_graph_write_node.go b/internal/flushcommon/pipeline/flow_graph_write_node.go index 9a8a5ed066..a0f6048f08 100644 --- a/internal/flushcommon/pipeline/flow_graph_write_node.go +++ b/internal/flushcommon/pipeline/flow_graph_write_node.go @@ -14,6 +14,7 @@ import ( "github.com/milvus-io/milvus/internal/flushcommon/metacache" "github.com/milvus-io/milvus/internal/flushcommon/util" "github.com/milvus-io/milvus/internal/flushcommon/writebuffer" + "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -98,7 +99,9 @@ func (wNode *writeNode) Operate(in []Msg) []Msg { }, true }) - wNode.updater.Update(wNode.channelName, end.GetTimestamp(), stats) + if !streamingutil.IsStreamingServiceEnabled() { + wNode.updater.Update(wNode.channelName, end.GetTimestamp(), stats) + } res := FlowGraphMsg{ TimeRange: fgMsg.TimeRange, diff --git a/internal/flushcommon/util/util.go b/internal/flushcommon/util/util.go index cfbe6b4507..0f6c8985e7 100644 --- a/internal/flushcommon/util/util.go +++ b/internal/flushcommon/util/util.go @@ -40,7 +40,7 @@ type PipelineParams struct { Ctx context.Context Broker broker.Broker SyncMgr syncmgr.SyncManager - TimeTickSender *TimeTickSender // reference to TimeTickSender + TimeTickSender StatsUpdater // reference to TimeTickSender CompactionExecutor compaction.Executor // reference to compaction executor MsgStreamFactory dependency.Factory DispClient msgdispatcher.Client diff --git a/internal/mocks/distributed/mock_streaming/mock_Utility.go b/internal/mocks/distributed/mock_streaming/mock_Utility.go deleted file mode 100644 index e2fa616a5a..0000000000 --- a/internal/mocks/distributed/mock_streaming/mock_Utility.go +++ /dev/null @@ -1,154 +0,0 @@ -// Code generated by mockery v2.32.4. DO NOT EDIT. 
- -package mock_streaming - -import ( - context "context" - - message "github.com/milvus-io/milvus/pkg/streaming/util/message" - mock "github.com/stretchr/testify/mock" - - streaming "github.com/milvus-io/milvus/internal/distributed/streaming" -) - -// MockUtility is an autogenerated mock type for the Utility type -type MockUtility struct { - mock.Mock -} - -type MockUtility_Expecter struct { - mock *mock.Mock -} - -func (_m *MockUtility) EXPECT() *MockUtility_Expecter { - return &MockUtility_Expecter{mock: &_m.Mock} -} - -// AppendMessages provides a mock function with given fields: ctx, msgs -func (_m *MockUtility) AppendMessages(ctx context.Context, msgs ...message.MutableMessage) streaming.AppendResponses { - _va := make([]interface{}, len(msgs)) - for _i := range msgs { - _va[_i] = msgs[_i] - } - var _ca []interface{} - _ca = append(_ca, ctx) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) - - var r0 streaming.AppendResponses - if rf, ok := ret.Get(0).(func(context.Context, ...message.MutableMessage) streaming.AppendResponses); ok { - r0 = rf(ctx, msgs...) - } else { - r0 = ret.Get(0).(streaming.AppendResponses) - } - - return r0 -} - -// MockUtility_AppendMessages_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AppendMessages' -type MockUtility_AppendMessages_Call struct { - *mock.Call -} - -// AppendMessages is a helper method to define mock.On call -// - ctx context.Context -// - msgs ...message.MutableMessage -func (_e *MockUtility_Expecter) AppendMessages(ctx interface{}, msgs ...interface{}) *MockUtility_AppendMessages_Call { - return &MockUtility_AppendMessages_Call{Call: _e.mock.On("AppendMessages", - append([]interface{}{ctx}, msgs...)...)} -} - -func (_c *MockUtility_AppendMessages_Call) Run(run func(ctx context.Context, msgs ...message.MutableMessage)) *MockUtility_AppendMessages_Call { - _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]message.MutableMessage, len(args)-1) - for i, a := range args[1:] { - if a != nil { - variadicArgs[i] = a.(message.MutableMessage) - } - } - run(args[0].(context.Context), variadicArgs...) - }) - return _c -} - -func (_c *MockUtility_AppendMessages_Call) Return(_a0 streaming.AppendResponses) *MockUtility_AppendMessages_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *MockUtility_AppendMessages_Call) RunAndReturn(run func(context.Context, ...message.MutableMessage) streaming.AppendResponses) *MockUtility_AppendMessages_Call { - _c.Call.Return(run) - return _c -} - -// AppendMessagesWithOption provides a mock function with given fields: ctx, opts, msgs -func (_m *MockUtility) AppendMessagesWithOption(ctx context.Context, opts streaming.AppendOption, msgs ...message.MutableMessage) streaming.AppendResponses { - _va := make([]interface{}, len(msgs)) - for _i := range msgs { - _va[_i] = msgs[_i] - } - var _ca []interface{} - _ca = append(_ca, ctx, opts) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) - - var r0 streaming.AppendResponses - if rf, ok := ret.Get(0).(func(context.Context, streaming.AppendOption, ...message.MutableMessage) streaming.AppendResponses); ok { - r0 = rf(ctx, opts, msgs...) 
- } else { - r0 = ret.Get(0).(streaming.AppendResponses) - } - - return r0 -} - -// MockUtility_AppendMessagesWithOption_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AppendMessagesWithOption' -type MockUtility_AppendMessagesWithOption_Call struct { - *mock.Call -} - -// AppendMessagesWithOption is a helper method to define mock.On call -// - ctx context.Context -// - opts streaming.AppendOption -// - msgs ...message.MutableMessage -func (_e *MockUtility_Expecter) AppendMessagesWithOption(ctx interface{}, opts interface{}, msgs ...interface{}) *MockUtility_AppendMessagesWithOption_Call { - return &MockUtility_AppendMessagesWithOption_Call{Call: _e.mock.On("AppendMessagesWithOption", - append([]interface{}{ctx, opts}, msgs...)...)} -} - -func (_c *MockUtility_AppendMessagesWithOption_Call) Run(run func(ctx context.Context, opts streaming.AppendOption, msgs ...message.MutableMessage)) *MockUtility_AppendMessagesWithOption_Call { - _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]message.MutableMessage, len(args)-2) - for i, a := range args[2:] { - if a != nil { - variadicArgs[i] = a.(message.MutableMessage) - } - } - run(args[0].(context.Context), args[1].(streaming.AppendOption), variadicArgs...) - }) - return _c -} - -func (_c *MockUtility_AppendMessagesWithOption_Call) Return(_a0 streaming.AppendResponses) *MockUtility_AppendMessagesWithOption_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *MockUtility_AppendMessagesWithOption_Call) RunAndReturn(run func(context.Context, streaming.AppendOption, ...message.MutableMessage) streaming.AppendResponses) *MockUtility_AppendMessagesWithOption_Call { - _c.Call.Return(run) - return _c -} - -// NewMockUtility creates a new instance of MockUtility. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewMockUtility(t interface { - mock.TestingT - Cleanup(func()) -}) *MockUtility { - mock := &MockUtility{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go b/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go index 5090e413bc..b346552c96 100644 --- a/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go +++ b/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go @@ -26,8 +26,123 @@ func (_m *MockWALAccesser) EXPECT() *MockWALAccesser_Expecter { return &MockWALAccesser_Expecter{mock: &_m.Mock} } -// Append provides a mock function with given fields: ctx, msgs, opts -func (_m *MockWALAccesser) Append(ctx context.Context, msgs message.MutableMessage, opts ...streaming.AppendOption) (*types.AppendResult, error) { +// AppendMessages provides a mock function with given fields: ctx, msgs +func (_m *MockWALAccesser) AppendMessages(ctx context.Context, msgs ...message.MutableMessage) streaming.AppendResponses { + _va := make([]interface{}, len(msgs)) + for _i := range msgs { + _va[_i] = msgs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 streaming.AppendResponses + if rf, ok := ret.Get(0).(func(context.Context, ...message.MutableMessage) streaming.AppendResponses); ok { + r0 = rf(ctx, msgs...) 
+ } else { + r0 = ret.Get(0).(streaming.AppendResponses) + } + + return r0 +} + +// MockWALAccesser_AppendMessages_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AppendMessages' +type MockWALAccesser_AppendMessages_Call struct { + *mock.Call +} + +// AppendMessages is a helper method to define mock.On call +// - ctx context.Context +// - msgs ...message.MutableMessage +func (_e *MockWALAccesser_Expecter) AppendMessages(ctx interface{}, msgs ...interface{}) *MockWALAccesser_AppendMessages_Call { + return &MockWALAccesser_AppendMessages_Call{Call: _e.mock.On("AppendMessages", + append([]interface{}{ctx}, msgs...)...)} +} + +func (_c *MockWALAccesser_AppendMessages_Call) Run(run func(ctx context.Context, msgs ...message.MutableMessage)) *MockWALAccesser_AppendMessages_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]message.MutableMessage, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(message.MutableMessage) + } + } + run(args[0].(context.Context), variadicArgs...) + }) + return _c +} + +func (_c *MockWALAccesser_AppendMessages_Call) Return(_a0 streaming.AppendResponses) *MockWALAccesser_AppendMessages_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockWALAccesser_AppendMessages_Call) RunAndReturn(run func(context.Context, ...message.MutableMessage) streaming.AppendResponses) *MockWALAccesser_AppendMessages_Call { + _c.Call.Return(run) + return _c +} + +// AppendMessagesWithOption provides a mock function with given fields: ctx, opts, msgs +func (_m *MockWALAccesser) AppendMessagesWithOption(ctx context.Context, opts streaming.AppendOption, msgs ...message.MutableMessage) streaming.AppendResponses { + _va := make([]interface{}, len(msgs)) + for _i := range msgs { + _va[_i] = msgs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, opts) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 streaming.AppendResponses + if rf, ok := ret.Get(0).(func(context.Context, streaming.AppendOption, ...message.MutableMessage) streaming.AppendResponses); ok { + r0 = rf(ctx, opts, msgs...) + } else { + r0 = ret.Get(0).(streaming.AppendResponses) + } + + return r0 +} + +// MockWALAccesser_AppendMessagesWithOption_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AppendMessagesWithOption' +type MockWALAccesser_AppendMessagesWithOption_Call struct { + *mock.Call +} + +// AppendMessagesWithOption is a helper method to define mock.On call +// - ctx context.Context +// - opts streaming.AppendOption +// - msgs ...message.MutableMessage +func (_e *MockWALAccesser_Expecter) AppendMessagesWithOption(ctx interface{}, opts interface{}, msgs ...interface{}) *MockWALAccesser_AppendMessagesWithOption_Call { + return &MockWALAccesser_AppendMessagesWithOption_Call{Call: _e.mock.On("AppendMessagesWithOption", + append([]interface{}{ctx, opts}, msgs...)...)} +} + +func (_c *MockWALAccesser_AppendMessagesWithOption_Call) Run(run func(ctx context.Context, opts streaming.AppendOption, msgs ...message.MutableMessage)) *MockWALAccesser_AppendMessagesWithOption_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]message.MutableMessage, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(message.MutableMessage) + } + } + run(args[0].(context.Context), args[1].(streaming.AppendOption), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockWALAccesser_AppendMessagesWithOption_Call) Return(_a0 streaming.AppendResponses) *MockWALAccesser_AppendMessagesWithOption_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockWALAccesser_AppendMessagesWithOption_Call) RunAndReturn(run func(context.Context, streaming.AppendOption, ...message.MutableMessage) streaming.AppendResponses) *MockWALAccesser_AppendMessagesWithOption_Call { + _c.Call.Return(run) + return _c +} + +// RawAppend provides a mock function with given fields: ctx, msgs, opts +func (_m *MockWALAccesser) RawAppend(ctx context.Context, msgs message.MutableMessage, opts ...streaming.AppendOption) (*types.AppendResult, error) { _va := make([]interface{}, len(opts)) for _i := range opts { _va[_i] = opts[_i] @@ -59,21 +174,21 @@ func (_m *MockWALAccesser) Append(ctx context.Context, msgs message.MutableMessa return r0, r1 } -// MockWALAccesser_Append_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Append' -type MockWALAccesser_Append_Call struct { +// MockWALAccesser_RawAppend_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RawAppend' +type MockWALAccesser_RawAppend_Call struct { *mock.Call } -// Append is a helper method to define mock.On call +// RawAppend is a helper method to define mock.On call // - ctx context.Context // - msgs message.MutableMessage // - opts ...streaming.AppendOption -func (_e *MockWALAccesser_Expecter) Append(ctx interface{}, msgs interface{}, opts ...interface{}) *MockWALAccesser_Append_Call { - return &MockWALAccesser_Append_Call{Call: _e.mock.On("Append", +func (_e *MockWALAccesser_Expecter) RawAppend(ctx interface{}, msgs interface{}, opts ...interface{}) *MockWALAccesser_RawAppend_Call { + return &MockWALAccesser_RawAppend_Call{Call: _e.mock.On("RawAppend", append([]interface{}{ctx, msgs}, opts...)...)} } -func (_c *MockWALAccesser_Append_Call) Run(run func(ctx context.Context, msgs message.MutableMessage, opts ...streaming.AppendOption)) *MockWALAccesser_Append_Call { +func (_c *MockWALAccesser_RawAppend_Call) Run(run func(ctx context.Context, msgs message.MutableMessage, opts ...streaming.AppendOption)) *MockWALAccesser_RawAppend_Call { _c.Call.Run(func(args mock.Arguments) { variadicArgs := make([]streaming.AppendOption, len(args)-2) for i, a := range args[2:] { @@ -86,12 +201,12 @@ func (_c *MockWALAccesser_Append_Call) Run(run func(ctx context.Context, msgs me return _c } -func (_c *MockWALAccesser_Append_Call) Return(_a0 *types.AppendResult, _a1 error) *MockWALAccesser_Append_Call { +func (_c *MockWALAccesser_RawAppend_Call) Return(_a0 *types.AppendResult, _a1 error) *MockWALAccesser_RawAppend_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *MockWALAccesser_Append_Call) RunAndReturn(run func(context.Context, message.MutableMessage, ...streaming.AppendOption) (*types.AppendResult, error)) *MockWALAccesser_Append_Call { +func (_c *MockWALAccesser_RawAppend_Call) RunAndReturn(run func(context.Context, message.MutableMessage, ...streaming.AppendOption) (*types.AppendResult, error)) *MockWALAccesser_RawAppend_Call { _c.Call.Return(run) return _c } @@ -196,49 +311,6 @@ func (_c *MockWALAccesser_Txn_Call) RunAndReturn(run func(context.Context, strea return _c } -// Utility provides a mock function with given fields: -func (_m *MockWALAccesser) Utility() streaming.Utility { - ret := _m.Called() - - var r0 streaming.Utility - if rf, ok := ret.Get(0).(func() streaming.Utility); ok { - r0 = rf() - } else { - if 
ret.Get(0) != nil { - r0 = ret.Get(0).(streaming.Utility) - } - } - - return r0 -} - -// MockWALAccesser_Utility_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Utility' -type MockWALAccesser_Utility_Call struct { - *mock.Call -} - -// Utility is a helper method to define mock.On call -func (_e *MockWALAccesser_Expecter) Utility() *MockWALAccesser_Utility_Call { - return &MockWALAccesser_Utility_Call{Call: _e.mock.On("Utility")} -} - -func (_c *MockWALAccesser_Utility_Call) Run(run func()) *MockWALAccesser_Utility_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *MockWALAccesser_Utility_Call) Return(_a0 streaming.Utility) *MockWALAccesser_Utility_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *MockWALAccesser_Utility_Call) RunAndReturn(run func() streaming.Utility) *MockWALAccesser_Utility_Call { - _c.Call.Return(run) - return _c -} - // NewMockWALAccesser creates a new instance of MockWALAccesser. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockWALAccesser(t interface { diff --git a/internal/mocks/streamingnode/server/wal/interceptors/segment/mock_inspector/mock_SealOperator.go b/internal/mocks/streamingnode/server/wal/interceptors/segment/mock_inspector/mock_SealOperator.go index 697da341fe..2346b3cd09 100644 --- a/internal/mocks/streamingnode/server/wal/interceptors/segment/mock_inspector/mock_SealOperator.go +++ b/internal/mocks/streamingnode/server/wal/interceptors/segment/mock_inspector/mock_SealOperator.go @@ -107,6 +107,54 @@ func (_c *MockSealOperator_IsNoWaitSeal_Call) RunAndReturn(run func() bool) *Moc return _c } +// MustSealSegments provides a mock function with given fields: ctx, infos +func (_m *MockSealOperator) MustSealSegments(ctx context.Context, infos ...stats.SegmentBelongs) { + _va := make([]interface{}, len(infos)) + for _i := range infos { + _va[_i] = infos[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx) + _ca = append(_ca, _va...) + _m.Called(_ca...) +} + +// MockSealOperator_MustSealSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MustSealSegments' +type MockSealOperator_MustSealSegments_Call struct { + *mock.Call +} + +// MustSealSegments is a helper method to define mock.On call +// - ctx context.Context +// - infos ...stats.SegmentBelongs +func (_e *MockSealOperator_Expecter) MustSealSegments(ctx interface{}, infos ...interface{}) *MockSealOperator_MustSealSegments_Call { + return &MockSealOperator_MustSealSegments_Call{Call: _e.mock.On("MustSealSegments", + append([]interface{}{ctx}, infos...)...)} +} + +func (_c *MockSealOperator_MustSealSegments_Call) Run(run func(ctx context.Context, infos ...stats.SegmentBelongs)) *MockSealOperator_MustSealSegments_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]stats.SegmentBelongs, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(stats.SegmentBelongs) + } + } + run(args[0].(context.Context), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockSealOperator_MustSealSegments_Call) Return() *MockSealOperator_MustSealSegments_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSealOperator_MustSealSegments_Call) RunAndReturn(run func(context.Context, ...stats.SegmentBelongs)) *MockSealOperator_MustSealSegments_Call { + _c.Call.Return(run) + return _c +} + // TryToSealSegments provides a mock function with given fields: ctx, infos func (_m *MockSealOperator) TryToSealSegments(ctx context.Context, infos ...stats.SegmentBelongs) { _va := make([]interface{}, len(infos)) diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index d44c010563..1ee1a0f214 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -45,6 +45,7 @@ import ( "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/hookutil" "github.com/milvus-io/milvus/internal/util/importutilv2" + "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" @@ -2534,6 +2535,12 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest) chMgr: node.chMgr, chTicker: node.chTicker, } + var enqueuedTask task = it + if streamingutil.IsStreamingServiceEnabled() { + enqueuedTask = &insertTaskByStreamingService{ + insertTask: it, + } + } constructFailedResponse := func(err error) *milvuspb.MutationResult { numRows := request.NumRows @@ -2550,7 +2557,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest) log.Debug("Enqueue insert request in Proxy") - if err := node.sched.dmQueue.Enqueue(it); err != nil { + if err := node.sched.dmQueue.Enqueue(enqueuedTask); err != nil { log.Warn("Failed to enqueue insert task: " + err.Error()) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel, request.GetDbName(), request.GetCollectionName()).Inc() @@ -2769,12 +2776,18 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) chMgr: node.chMgr, chTicker: node.chTicker, } + var enqueuedTask task = it + if streamingutil.IsStreamingServiceEnabled() { + enqueuedTask = &upsertTaskByStreamingService{ + upsertTask: it, + } + } log.Debug("Enqueue upsert request in Proxy", zap.Int("len(FieldsData)", len(request.FieldsData)), zap.Int("len(HashKeys)", len(request.HashKeys))) - if err := node.sched.dmQueue.Enqueue(it); err != nil { + if err := node.sched.dmQueue.Enqueue(enqueuedTask); err != nil { log.Info("Failed to enqueue upsert task", zap.Error(err)) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, @@ -3376,7 +3389,15 @@ func (node *Proxy) Flush(ctx context.Context, request *milvuspb.FlushRequest) (* log.Debug(rpcReceived(method)) - if err := node.sched.dcQueue.Enqueue(ft); err != nil { + var enqueuedTask task = ft + if streamingutil.IsStreamingServiceEnabled() { + enqueuedTask = &flushTaskByStreamingService{ + flushTask: ft, + chMgr: node.chMgr, + } + } + + if err := node.sched.dcQueue.Enqueue(enqueuedTask); err != nil { log.Warn( rpcFailedToEnqueue(method), zap.Error(err)) diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 12cb84410b..241ec76783 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -39,6 +39,7 @@ import ( "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/hookutil" 
"github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgstream" @@ -400,19 +401,21 @@ func (node *Proxy) Start() error { } log.Debug("start id allocator done", zap.String("role", typeutil.ProxyRole)) - if err := node.segAssigner.Start(); err != nil { - log.Warn("failed to start segment id assigner", zap.String("role", typeutil.ProxyRole), zap.Error(err)) - return err - } - log.Debug("start segment id assigner done", zap.String("role", typeutil.ProxyRole)) + if !streamingutil.IsStreamingServiceEnabled() { + if err := node.segAssigner.Start(); err != nil { + log.Warn("failed to start segment id assigner", zap.String("role", typeutil.ProxyRole), zap.Error(err)) + return err + } + log.Debug("start segment id assigner done", zap.String("role", typeutil.ProxyRole)) - if err := node.chTicker.start(); err != nil { - log.Warn("failed to start channels time ticker", zap.String("role", typeutil.ProxyRole), zap.Error(err)) - return err - } - log.Debug("start channels time ticker done", zap.String("role", typeutil.ProxyRole)) + if err := node.chTicker.start(); err != nil { + log.Warn("failed to start channels time ticker", zap.String("role", typeutil.ProxyRole), zap.Error(err)) + return err + } + log.Debug("start channels time ticker done", zap.String("role", typeutil.ProxyRole)) - node.sendChannelsTimeTickLoop() + node.sendChannelsTimeTickLoop() + } // Start callbacks for _, cb := range node.startCallbacks { @@ -440,22 +443,24 @@ func (node *Proxy) Stop() error { log.Info("close id allocator", zap.String("role", typeutil.ProxyRole)) } - if node.segAssigner != nil { - node.segAssigner.Close() - log.Info("close segment id assigner", zap.String("role", typeutil.ProxyRole)) - } - if node.sched != nil { node.sched.Close() log.Info("close scheduler", zap.String("role", typeutil.ProxyRole)) } - if node.chTicker != nil { - err := node.chTicker.close() - if err != nil { - return err + if !streamingutil.IsStreamingServiceEnabled() { + if node.segAssigner != nil { + node.segAssigner.Close() + log.Info("close segment id assigner", zap.String("role", typeutil.ProxyRole)) + } + + if node.chTicker != nil { + err := node.chTicker.close() + if err != nil { + return err + } + log.Info("close channels time ticker", zap.String("role", typeutil.ProxyRole)) } - log.Info("close channels time ticker", zap.String("role", typeutil.ProxyRole)) } for _, cb := range node.closeCallbacks { diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 94ace67712..03d743341a 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -28,9 +28,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" - "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" @@ -1434,107 +1432,6 @@ func (t *showPartitionsTask) PostExecute(ctx context.Context) error { return nil } -type flushTask struct { - baseTask - Condition - *milvuspb.FlushRequest - ctx context.Context - dataCoord types.DataCoordClient - result *milvuspb.FlushResponse - - replicateMsgStream msgstream.MsgStream -} - -func (t *flushTask) TraceCtx() context.Context { - return 
t.ctx -} - -func (t *flushTask) ID() UniqueID { - return t.Base.MsgID -} - -func (t *flushTask) SetID(uid UniqueID) { - t.Base.MsgID = uid -} - -func (t *flushTask) Name() string { - return FlushTaskName -} - -func (t *flushTask) Type() commonpb.MsgType { - return t.Base.MsgType -} - -func (t *flushTask) BeginTs() Timestamp { - return t.Base.Timestamp -} - -func (t *flushTask) EndTs() Timestamp { - return t.Base.Timestamp -} - -func (t *flushTask) SetTs(ts Timestamp) { - t.Base.Timestamp = ts -} - -func (t *flushTask) OnEnqueue() error { - if t.Base == nil { - t.Base = commonpbutil.NewMsgBase() - } - t.Base.MsgType = commonpb.MsgType_Flush - t.Base.SourceID = paramtable.GetNodeID() - return nil -} - -func (t *flushTask) PreExecute(ctx context.Context) error { - return nil -} - -func (t *flushTask) Execute(ctx context.Context) error { - coll2Segments := make(map[string]*schemapb.LongArray) - flushColl2Segments := make(map[string]*schemapb.LongArray) - coll2SealTimes := make(map[string]int64) - coll2FlushTs := make(map[string]Timestamp) - channelCps := make(map[string]*msgpb.MsgPosition) - for _, collName := range t.CollectionNames { - collID, err := globalMetaCache.GetCollectionID(ctx, t.GetDbName(), collName) - if err != nil { - return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound) - } - flushReq := &datapb.FlushRequest{ - Base: commonpbutil.UpdateMsgBase( - t.Base, - commonpbutil.WithMsgType(commonpb.MsgType_Flush), - ), - CollectionID: collID, - } - resp, err := t.dataCoord.Flush(ctx, flushReq) - if err = merr.CheckRPCCall(resp, err); err != nil { - return fmt.Errorf("failed to call flush to data coordinator: %s", err.Error()) - } - coll2Segments[collName] = &schemapb.LongArray{Data: resp.GetSegmentIDs()} - flushColl2Segments[collName] = &schemapb.LongArray{Data: resp.GetFlushSegmentIDs()} - coll2SealTimes[collName] = resp.GetTimeOfSeal() - coll2FlushTs[collName] = resp.GetFlushTs() - channelCps = resp.GetChannelCps() - } - SendReplicateMessagePack(ctx, t.replicateMsgStream, t.FlushRequest) - t.result = &milvuspb.FlushResponse{ - Status: merr.Success(), - DbName: t.GetDbName(), - CollSegIDs: coll2Segments, - FlushCollSegIDs: flushColl2Segments, - CollSealTimes: coll2SealTimes, - CollFlushTs: coll2FlushTs, - ChannelCps: channelCps, - } - return nil -} - -func (t *flushTask) PostExecute(ctx context.Context) error { - return nil -} - type loadCollectionTask struct { baseTask Condition diff --git a/internal/proxy/task_delete.go b/internal/proxy/task_delete.go index 4e2e202325..95e15b92ea 100644 --- a/internal/proxy/task_delete.go +++ b/internal/proxy/task_delete.go @@ -22,6 +22,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/exprutil" + "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" @@ -63,6 +64,8 @@ type deleteTask struct { // result count int64 allQueryCnt int64 + + sessionTS Timestamp } func (dt *deleteTask) TraceCtx() context.Context { @@ -142,28 +145,19 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) { return err } - hashValues := typeutil.HashPK2Channels(dt.primaryKeys, dt.vChannels) - // repack delete msg by dmChannel - result := make(map[uint32]msgstream.TsMsg) - numRows := int64(0) - for index, key := range hashValues { - vchannel := dt.vChannels[key] - _, ok := result[key] - if !ok { 
- deleteMsg, err := dt.newDeleteMsg(ctx) - if err != nil { - return err - } - deleteMsg.ShardName = vchannel - result[key] = deleteMsg - } - curMsg := result[key].(*msgstream.DeleteMsg) - curMsg.HashValues = append(curMsg.HashValues, hashValues[index]) - curMsg.Timestamps = append(curMsg.Timestamps, dt.ts) - - typeutil.AppendIDs(curMsg.PrimaryKeys, dt.primaryKeys, index) - curMsg.NumRows++ - numRows++ + result, numRows, err := repackDeleteMsgByHash( + ctx, + dt.primaryKeys, + dt.vChannels, + dt.idAllocator, + dt.ts, + dt.collectionID, + dt.req.GetCollectionName(), + dt.partitionID, + dt.req.GetPartitionName(), + ) + if err != nil { + return err } // send delete request to log broker @@ -189,6 +183,7 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) { if err != nil { return err } + dt.sessionTS = dt.ts dt.count += numRows return nil } @@ -197,30 +192,84 @@ func (dt *deleteTask) PostExecute(ctx context.Context) error { return nil } -func (dt *deleteTask) newDeleteMsg(ctx context.Context) (*msgstream.DeleteMsg, error) { - msgid, err := dt.idAllocator.AllocOne() +func repackDeleteMsgByHash( + ctx context.Context, + primaryKeys *schemapb.IDs, + vChannels []string, + idAllocator allocator.Interface, + ts uint64, + collectionID int64, + collectionName string, + partitionID int64, + partitionName string, +) (map[uint32]*msgstream.DeleteMsg, int64, error) { + hashValues := typeutil.HashPK2Channels(primaryKeys, vChannels) + // repack delete msg by dmChannel + result := make(map[uint32]*msgstream.DeleteMsg) + numRows := int64(0) + for index, key := range hashValues { + vchannel := vChannels[key] + _, ok := result[key] + if !ok { + deleteMsg, err := newDeleteMsg( + ctx, + idAllocator, + ts, + collectionID, + collectionName, + partitionID, + partitionName, + ) + if err != nil { + return nil, 0, err + } + deleteMsg.ShardName = vchannel + result[key] = deleteMsg + } + curMsg := result[key] + curMsg.HashValues = append(curMsg.HashValues, hashValues[index]) + curMsg.Timestamps = append(curMsg.Timestamps, ts) + + typeutil.AppendIDs(curMsg.PrimaryKeys, primaryKeys, index) + curMsg.NumRows++ + numRows++ + } + return result, numRows, nil +} + +func newDeleteMsg( + ctx context.Context, + idAllocator allocator.Interface, + ts uint64, + collectionID int64, + collectionName string, + partitionID int64, + partitionName string, +) (*msgstream.DeleteMsg, error) { + msgid, err := idAllocator.AllocOne() if err != nil { return nil, errors.Wrap(err, "failed to allocate MsgID of delete") } + sliceRequest := &msgpb.DeleteRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_Delete), + // msgid of delete msg must be set + // or it will be seen as duplicated msg in mq + commonpbutil.WithMsgID(msgid), + commonpbutil.WithTimeStamp(ts), + commonpbutil.WithSourceID(paramtable.GetNodeID()), + ), + CollectionID: collectionID, + PartitionID: partitionID, + CollectionName: collectionName, + PartitionName: partitionName, + PrimaryKeys: &schemapb.IDs{}, + } return &msgstream.DeleteMsg{ BaseMsg: msgstream.BaseMsg{ Ctx: ctx, }, - DeleteRequest: &msgpb.DeleteRequest{ - Base: commonpbutil.NewMsgBase( - commonpbutil.WithMsgType(commonpb.MsgType_Delete), - // msgid of delete msg must be set - // or it will be seen as duplicated msg in mq - commonpbutil.WithMsgID(msgid), - commonpbutil.WithTimeStamp(dt.ts), - commonpbutil.WithSourceID(paramtable.GetNodeID()), - ), - CollectionID: dt.collectionID, - PartitionID: dt.partitionID, - CollectionName: dt.req.GetCollectionName(), - PartitionName: 
dt.req.GetPartitionName(), - PrimaryKeys: &schemapb.IDs{}, - }, + DeleteRequest: sliceRequest, }, nil } @@ -254,6 +303,7 @@ type deleteRunner struct { queue *dmTaskQueue allQueryCnt atomic.Int64 + sessionTS atomic.Uint64 } func (dr *deleteRunner) Init(ctx context.Context) error { @@ -346,7 +396,7 @@ func (dr *deleteRunner) Run(ctx context.Context) error { } func (dr *deleteRunner) produce(ctx context.Context, primaryKeys *schemapb.IDs) (*deleteTask, error) { - task := &deleteTask{ + dt := &deleteTask{ ctx: ctx, Condition: NewTaskCondition(ctx), req: dr.req, @@ -359,13 +409,17 @@ func (dr *deleteRunner) produce(ctx context.Context, primaryKeys *schemapb.IDs) vChannels: dr.vChannels, primaryKeys: primaryKeys, } + var enqueuedTask task = dt + if streamingutil.IsStreamingServiceEnabled() { + enqueuedTask = &deleteTaskByStreamingService{deleteTask: dt} + } - if err := dr.queue.Enqueue(task); err != nil { + if err := dr.queue.Enqueue(enqueuedTask); err != nil { log.Error("Failed to enqueue delete task: " + err.Error()) return nil, err } - return task, nil + return dt, nil } // getStreamingQueryAndDelteFunc return query function used by LBPolicy @@ -447,6 +501,7 @@ func (dr *deleteRunner) getStreamingQueryAndDelteFunc(plan *planpb.PlanNode) exe }() var allQueryCnt int64 // wait all task finish + var sessionTS uint64 for task := range taskCh { err := task.WaitToFinish() if err != nil { @@ -454,6 +509,9 @@ func (dr *deleteRunner) getStreamingQueryAndDelteFunc(plan *planpb.PlanNode) exe } dr.count.Add(task.count) allQueryCnt += task.allQueryCnt + if sessionTS < task.sessionTS { + sessionTS = task.sessionTS + } } // query or produce task failed @@ -461,6 +519,7 @@ func (dr *deleteRunner) getStreamingQueryAndDelteFunc(plan *planpb.PlanNode) exe return receiveErr } dr.allQueryCnt.Add(allQueryCnt) + dr.sessionTS.Store(sessionTS) return nil } } @@ -523,6 +582,7 @@ func (dr *deleteRunner) complexDelete(ctx context.Context, plan *planpb.PlanNode exec: dr.getStreamingQueryAndDelteFunc(plan), }) dr.result.DeleteCnt = dr.count.Load() + dr.result.Timestamp = dr.sessionTS.Load() if err != nil { log.Warn("fail to execute complex delete", zap.Int64("deleteCnt", dr.result.GetDeleteCnt()), @@ -550,6 +610,7 @@ func (dr *deleteRunner) simpleDelete(ctx context.Context, pk *schemapb.IDs, numR err = task.WaitToFinish() if err == nil { dr.result.DeleteCnt = task.count + dr.result.Timestamp = task.sessionTS } return err } diff --git a/internal/proxy/task_delete_streaming.go b/internal/proxy/task_delete_streaming.go new file mode 100644 index 0000000000..cc46130ea8 --- /dev/null +++ b/internal/proxy/task_delete_streaming.go @@ -0,0 +1,79 @@ +package proxy + +import ( + "context" + "fmt" + + "go.opentelemetry.io/otel" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/distributed/streaming" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/timerecord" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type deleteTaskByStreamingService struct { + *deleteTask +} + +// Execute is a function to delete task by streaming service +// we only overwrite the Execute function +func (dt *deleteTaskByStreamingService) Execute(ctx context.Context) (err error) { + ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Delete-Execute") + defer sp.End() + + if len(dt.req.GetExpr()) == 0 { + return merr.WrapErrParameterInvalid("valid expr", "empty expr", "invalid expression") + } + + 
dt.tr = timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute delete %d", dt.ID())) + result, numRows, err := repackDeleteMsgByHash( + ctx, + dt.primaryKeys, + dt.vChannels, + dt.idAllocator, + dt.ts, + dt.collectionID, + dt.req.GetCollectionName(), + dt.partitionID, + dt.req.GetPartitionName(), + ) + if err != nil { + return err + } + + var msgs []message.MutableMessage + for hashKey, deleteMsg := range result { + vchannel := dt.vChannels[hashKey] + msg, err := message.NewDeleteMessageBuilderV1(). + WithHeader(&message.DeleteMessageHeader{ + CollectionId: dt.collectionID, + }). + WithBody(deleteMsg.DeleteRequest). + WithVChannel(vchannel). + BuildMutable() + if err != nil { + return err + } + msgs = append(msgs, msg) + } + + log.Debug("send delete request to virtual channels", + zap.String("collectionName", dt.req.GetCollectionName()), + zap.Int64("collectionID", dt.collectionID), + zap.Strings("virtual_channels", dt.vChannels), + zap.Int64("taskID", dt.ID()), + zap.Duration("prepare duration", dt.tr.RecordSpan())) + + resp := streaming.WAL().AppendMessages(ctx, msgs...) + if err := resp.UnwrapFirstError(); err != nil { + log.Warn("append messages to wal failed", zap.Error(err)) + return err + } + dt.sessionTS = resp.MaxTimeTick() + dt.count += numRows + return nil +} diff --git a/internal/proxy/task_flush.go b/internal/proxy/task_flush.go new file mode 100644 index 0000000000..44beeccb72 --- /dev/null +++ b/internal/proxy/task_flush.go @@ -0,0 +1,134 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +package proxy + +import ( + "context" + "fmt" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +type flushTask struct { + baseTask + Condition + *milvuspb.FlushRequest + ctx context.Context + dataCoord types.DataCoordClient + result *milvuspb.FlushResponse + + replicateMsgStream msgstream.MsgStream +} + +func (t *flushTask) TraceCtx() context.Context { + return t.ctx +} + +func (t *flushTask) ID() UniqueID { + return t.Base.MsgID +} + +func (t *flushTask) SetID(uid UniqueID) { + t.Base.MsgID = uid +} + +func (t *flushTask) Name() string { + return FlushTaskName +} + +func (t *flushTask) Type() commonpb.MsgType { + return t.Base.MsgType +} + +func (t *flushTask) BeginTs() Timestamp { + return t.Base.Timestamp +} + +func (t *flushTask) EndTs() Timestamp { + return t.Base.Timestamp +} + +func (t *flushTask) SetTs(ts Timestamp) { + t.Base.Timestamp = ts +} + +func (t *flushTask) OnEnqueue() error { + if t.Base == nil { + t.Base = commonpbutil.NewMsgBase() + } + t.Base.MsgType = commonpb.MsgType_Flush + t.Base.SourceID = paramtable.GetNodeID() + return nil +} + +func (t *flushTask) PreExecute(ctx context.Context) error { + return nil +} + +func (t *flushTask) Execute(ctx context.Context) error { + coll2Segments := make(map[string]*schemapb.LongArray) + flushColl2Segments := make(map[string]*schemapb.LongArray) + coll2SealTimes := make(map[string]int64) + coll2FlushTs := make(map[string]Timestamp) + channelCps := make(map[string]*msgpb.MsgPosition) + for _, collName := range t.CollectionNames { + collID, err := globalMetaCache.GetCollectionID(ctx, t.GetDbName(), collName) + if err != nil { + return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound) + } + flushReq := &datapb.FlushRequest{ + Base: commonpbutil.UpdateMsgBase( + t.Base, + commonpbutil.WithMsgType(commonpb.MsgType_Flush), + ), + CollectionID: collID, + } + resp, err := t.dataCoord.Flush(ctx, flushReq) + if err = merr.CheckRPCCall(resp, err); err != nil { + return fmt.Errorf("failed to call flush to data coordinator: %s", err.Error()) + } + coll2Segments[collName] = &schemapb.LongArray{Data: resp.GetSegmentIDs()} + flushColl2Segments[collName] = &schemapb.LongArray{Data: resp.GetFlushSegmentIDs()} + coll2SealTimes[collName] = resp.GetTimeOfSeal() + coll2FlushTs[collName] = resp.GetFlushTs() + channelCps = resp.GetChannelCps() + } + SendReplicateMessagePack(ctx, t.replicateMsgStream, t.FlushRequest) + t.result = &milvuspb.FlushResponse{ + Status: merr.Success(), + DbName: t.GetDbName(), + CollSegIDs: coll2Segments, + FlushCollSegIDs: flushColl2Segments, + CollSealTimes: coll2SealTimes, + CollFlushTs: coll2FlushTs, + ChannelCps: channelCps, + } + return nil +} + +func (t *flushTask) PostExecute(ctx context.Context) error { + return nil +} diff --git a/internal/proxy/task_flush_streaming.go b/internal/proxy/task_flush_streaming.go new file mode 100644 index 0000000000..e0cc8625ad --- /dev/null +++ b/internal/proxy/task_flush_streaming.go @@ -0,0 +1,148 @@ +// Licensed to the LF AI & Data foundation under one +// or more 
contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "context" + "fmt" + + "github.com/pingcap/log" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/distributed/streaming" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/tsoutil" +) + +type flushTaskByStreamingService struct { + *flushTask + chMgr channelsMgr +} + +func (t *flushTaskByStreamingService) Execute(ctx context.Context) error { + coll2Segments := make(map[string]*schemapb.LongArray) + flushColl2Segments := make(map[string]*schemapb.LongArray) + coll2SealTimes := make(map[string]int64) + coll2FlushTs := make(map[string]Timestamp) + channelCps := make(map[string]*msgpb.MsgPosition) + + flushTs := t.BeginTs() + log.Info("flushTaskByStreamingService.Execute", zap.Int("collectionNum", len(t.CollectionNames)), zap.Uint64("flushTs", flushTs)) + timeOfSeal, _ := tsoutil.ParseTS(flushTs) + for _, collName := range t.CollectionNames { + collID, err := globalMetaCache.GetCollectionID(t.ctx, t.DbName, collName) + if err != nil { + return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound) + } + vchannels, err := t.chMgr.getVChannels(collID) + if err != nil { + return err + } + onFlushSegmentIDs := make([]int64, 0) + + // Ask the streamingnode to flush segments. + for _, vchannel := range vchannels { + segmentIDs, err := t.sendManualFlushToWAL(ctx, collID, vchannel, flushTs) + if err != nil { + return err + } + onFlushSegmentIDs = append(onFlushSegmentIDs, segmentIDs...) + } + + // Ask datacoord to get flushed segment infos. + flushReq := &datapb.FlushRequest{ + Base: commonpbutil.UpdateMsgBase( + t.Base, + commonpbutil.WithMsgType(commonpb.MsgType_Flush), + ), + CollectionID: collID, + } + resp, err := t.dataCoord.Flush(ctx, flushReq) + if err = merr.CheckRPCCall(resp, err); err != nil { + return fmt.Errorf("failed to call flush to data coordinator: %s", err.Error()) + } + + // Remove the flushed segments from onFlushSegmentIDs + for _, segID := range resp.GetFlushSegmentIDs() { + for i, id := range onFlushSegmentIDs { + if id == segID { + onFlushSegmentIDs = append(onFlushSegmentIDs[:i], onFlushSegmentIDs[i+1:]...) 
+ break + } + } + } + + coll2Segments[collName] = &schemapb.LongArray{Data: onFlushSegmentIDs} + flushColl2Segments[collName] = &schemapb.LongArray{Data: resp.GetFlushSegmentIDs()} + coll2SealTimes[collName] = timeOfSeal.Unix() + coll2FlushTs[collName] = flushTs + channelCps = resp.GetChannelCps() + } + // TODO: refactor to use streaming service + SendReplicateMessagePack(ctx, t.replicateMsgStream, t.FlushRequest) + t.result = &milvuspb.FlushResponse{ + Status: merr.Success(), + DbName: t.GetDbName(), + CollSegIDs: coll2Segments, + FlushCollSegIDs: flushColl2Segments, + CollSealTimes: coll2SealTimes, + CollFlushTs: coll2FlushTs, + ChannelCps: channelCps, + } + return nil +} + +// sendManualFlushToWAL sends a manual flush message to WAL. +func (t *flushTaskByStreamingService) sendManualFlushToWAL(ctx context.Context, collID int64, vchannel string, flushTs uint64) ([]int64, error) { + logger := log.With(zap.Int64("collectionID", collID), zap.String("vchannel", vchannel)) + flushMsg, err := message.NewManualFlushMessageBuilderV2(). + WithVChannel(vchannel). + WithHeader(&message.ManualFlushMessageHeader{ + CollectionId: collID, + FlushTs: flushTs, + }). + WithBody(&message.ManualFlushMessageBody{}). + BuildMutable() + if err != nil { + logger.Warn("build manual flush message failed", zap.Error(err)) + return nil, err + } + + appendResult, err := streaming.WAL().RawAppend(ctx, flushMsg, streaming.AppendOption{ + BarrierTimeTick: flushTs, + }) + if err != nil { + logger.Warn("append manual flush message to wal failed", zap.Error(err)) + return nil, err + } + + var flushMsgResponse message.ManualFlushExtraResponse + if err := appendResult.GetExtra(&flushMsgResponse); err != nil { + logger.Warn("get extra from append result failed", zap.Error(err)) + return nil, err + } + logger.Info("append manual flush message to wal successfully") + + return flushMsgResponse.GetSegmentIds(), nil +} diff --git a/internal/proxy/task_insert_streaming.go b/internal/proxy/task_insert_streaming.go new file mode 100644 index 0000000000..a452816f12 --- /dev/null +++ b/internal/proxy/task_insert_streaming.go @@ -0,0 +1,200 @@ +package proxy + +import ( + "context" + "fmt" + + "go.opentelemetry.io/otel" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/distributed/streaming" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/timerecord" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type insertTaskByStreamingService struct { + *insertTask +} + +// we only overwrite the Execute function +func (it *insertTaskByStreamingService) Execute(ctx context.Context) error { + ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Insert-Execute") + defer sp.End() + + tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute insert streaming %d", it.ID())) + + collectionName := it.insertMsg.CollectionName + collID, err := globalMetaCache.GetCollectionID(it.ctx, it.insertMsg.GetDbName(), collectionName) + log := log.Ctx(ctx) + if err != nil { + log.Warn("fail to get collection id", zap.Error(err)) + return err + } + it.insertMsg.CollectionID = collID + + getCacheDur := tr.RecordSpan() + channelNames, err := it.chMgr.getVChannels(collID) + if err != nil { + log.Warn("get vChannels failed", zap.Int64("collectionID", 
collID), zap.Error(err)) + it.result.Status = merr.Status(err) + return err + } + + log.Debug("send insert request to virtual channels", + zap.String("partition", it.insertMsg.GetPartitionName()), + zap.Int64("collectionID", collID), + zap.Strings("virtual_channels", channelNames), + zap.Int64("task_id", it.ID()), + zap.Bool("is_partition_key", it.partitionKeys != nil), + zap.Duration("get cache duration", getCacheDur)) + + // start to repack insert data + var msgs []message.MutableMessage + if it.partitionKeys == nil { + msgs, err = repackInsertDataForStreamingService(it.TraceCtx(), channelNames, it.insertMsg, it.result) + } else { + msgs, err = repackInsertDataWithPartitionKeyForStreamingService(it.TraceCtx(), channelNames, it.insertMsg, it.result, it.partitionKeys) + } + if err != nil { + log.Warn("assign segmentID and repack insert data failed", zap.Error(err)) + it.result.Status = merr.Status(err) + return err + } + resp := streaming.WAL().AppendMessages(ctx, msgs...) + if err := resp.UnwrapFirstError(); err != nil { + log.Warn("append messages to wal failed", zap.Error(err)) + it.result.Status = merr.Status(err) + return err + } + // Update result.Timestamp for session consistency. + it.result.Timestamp = resp.MaxTimeTick() + return nil +} + +func repackInsertDataForStreamingService( + ctx context.Context, + channelNames []string, + insertMsg *msgstream.InsertMsg, + result *milvuspb.MutationResult, +) ([]message.MutableMessage, error) { + messages := make([]message.MutableMessage, 0) + + channel2RowOffsets := assignChannelsByPK(result.IDs, channelNames, insertMsg) + for channel, rowOffsets := range channel2RowOffsets { + partitionName := insertMsg.PartitionName + partitionID, err := globalMetaCache.GetPartitionID(ctx, insertMsg.GetDbName(), insertMsg.CollectionName, partitionName) + if err != nil { + return nil, err + } + // segment id is assigned at streaming node. + msgs, err := genInsertMsgsByPartition(ctx, 0, partitionID, partitionName, rowOffsets, channel, insertMsg) + if err != nil { + return nil, err + } + for _, msg := range msgs { + newMsg, err := message.NewInsertMessageBuilderV1(). + WithVChannel(channel). + WithHeader(&message.InsertMessageHeader{ + CollectionId: insertMsg.CollectionID, + Partitions: []*message.PartitionSegmentAssignment{ + { + PartitionId: partitionID, + Rows: uint64(len(rowOffsets)), + BinarySize: 0, // TODO: current not used, message estimate size is used. + }, + }, + }). + WithBody(msg.(*msgstream.InsertMsg).InsertRequest).
+ BuildMutable() + if err != nil { + return nil, err + } + messages = append(messages, newMsg) + } + } + return messages, nil +} + +func repackInsertDataWithPartitionKeyForStreamingService( + ctx context.Context, + channelNames []string, + insertMsg *msgstream.InsertMsg, + result *milvuspb.MutationResult, + partitionKeys *schemapb.FieldData, +) ([]message.MutableMessage, error) { + messages := make([]message.MutableMessage, 0) + + channel2RowOffsets := assignChannelsByPK(result.IDs, channelNames, insertMsg) + partitionNames, err := getDefaultPartitionsInPartitionKeyMode(ctx, insertMsg.GetDbName(), insertMsg.CollectionName) + if err != nil { + log.Warn("get default partition names failed in partition key mode", + zap.String("collectionName", insertMsg.CollectionName), + zap.Error(err)) + return nil, err + } + + // Get partition ids + partitionIDs := make(map[string]int64, 0) + for _, partitionName := range partitionNames { + partitionID, err := globalMetaCache.GetPartitionID(ctx, insertMsg.GetDbName(), insertMsg.CollectionName, partitionName) + if err != nil { + log.Warn("get partition id failed", + zap.String("collectionName", insertMsg.CollectionName), + zap.String("partitionName", partitionName), + zap.Error(err)) + return nil, err + } + partitionIDs[partitionName] = partitionID + } + + hashValues, err := typeutil.HashKey2Partitions(partitionKeys, partitionNames) + if err != nil { + log.Warn("has partition keys to partitions failed", + zap.String("collectionName", insertMsg.CollectionName), + zap.Error(err)) + return nil, err + } + for channel, rowOffsets := range channel2RowOffsets { + partition2RowOffsets := make(map[string][]int) + for _, idx := range rowOffsets { + partitionName := partitionNames[hashValues[idx]] + if _, ok := partition2RowOffsets[partitionName]; !ok { + partition2RowOffsets[partitionName] = []int{} + } + partition2RowOffsets[partitionName] = append(partition2RowOffsets[partitionName], idx) + } + + for partitionName, rowOffsets := range partition2RowOffsets { + msgs, err := genInsertMsgsByPartition(ctx, 0, partitionIDs[partitionName], partitionName, rowOffsets, channel, insertMsg) + if err != nil { + return nil, err + } + for _, msg := range msgs { + newMsg, err := message.NewInsertMessageBuilderV1(). + WithVChannel(channel). + WithHeader(&message.InsertMessageHeader{ + CollectionId: insertMsg.CollectionID, + Partitions: []*message.PartitionSegmentAssignment{ + { + PartitionId: partitionIDs[partitionName], + Rows: uint64(len(rowOffsets)), + BinarySize: 0, // TODO: current not used, message estimate size is used. + }, + }, + }). + WithBody(msg.(*msgstream.InsertMsg).InsertRequest). 
+ BuildMutable() + if err != nil { + return nil, err + } + messages = append(messages, newMsg) + } + } + } + return messages, nil +} diff --git a/internal/proxy/task_upsert_streaming.go b/internal/proxy/task_upsert_streaming.go new file mode 100644 index 0000000000..2aba567034 --- /dev/null +++ b/internal/proxy/task_upsert_streaming.go @@ -0,0 +1,148 @@ +package proxy + +import ( + "context" + "fmt" + + "go.opentelemetry.io/otel" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/distributed/streaming" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/timerecord" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type upsertTaskByStreamingService struct { + *upsertTask +} + +func (ut *upsertTaskByStreamingService) Execute(ctx context.Context) error { + ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-Execute") + defer sp.End() + log := log.Ctx(ctx).With(zap.String("collectionName", ut.req.CollectionName)) + + insertMsgs, err := ut.packInsertMessage(ctx) + if err != nil { + log.Warn("pack insert message failed", zap.Error(err)) + return err + } + deleteMsgs, err := ut.packDeleteMessage(ctx) + if err != nil { + log.Warn("pack delete message failed", zap.Error(err)) + return err + } + + messages := append(insertMsgs, deleteMsgs...) + resp := streaming.WAL().AppendMessages(ctx, messages...) + if err := resp.UnwrapFirstError(); err != nil { + log.Warn("append messages to wal failed", zap.Error(err)) + return err + } + // Update result.Timestamp for session consistency. + ut.result.Timestamp = resp.MaxTimeTick() + return nil +} + +func (ut *upsertTaskByStreamingService) packInsertMessage(ctx context.Context) ([]message.MutableMessage, error) { + tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy insertExecute upsert %d", ut.ID())) + defer tr.Elapse("insert execute done when insertExecute") + + collectionName := ut.upsertMsg.InsertMsg.CollectionName + collID, err := globalMetaCache.GetCollectionID(ctx, ut.req.GetDbName(), collectionName) + if err != nil { + return nil, err + } + ut.upsertMsg.InsertMsg.CollectionID = collID + log := log.Ctx(ctx).With( + zap.Int64("collectionID", collID)) + getCacheDur := tr.RecordSpan() + + getMsgStreamDur := tr.RecordSpan() + channelNames, err := ut.chMgr.getVChannels(collID) + if err != nil { + log.Warn("get vChannels failed when insertExecute", + zap.Error(err)) + ut.result.Status = merr.Status(err) + return nil, err + } + + log.Debug("send insert request to virtual channels when insertExecute", + zap.String("collection", ut.req.GetCollectionName()), + zap.String("partition", ut.req.GetPartitionName()), + zap.Int64("collection_id", collID), + zap.Strings("virtual_channels", channelNames), + zap.Int64("task_id", ut.ID()), + zap.Duration("get cache duration", getCacheDur), + zap.Duration("get msgStream duration", getMsgStreamDur)) + + // start to repack insert data + var msgs []message.MutableMessage + if ut.partitionKeys == nil { + msgs, err = repackInsertDataForStreamingService(ut.TraceCtx(), channelNames, ut.upsertMsg.InsertMsg, ut.result) + } else { + msgs, err = repackInsertDataWithPartitionKeyForStreamingService(ut.TraceCtx(), channelNames, ut.upsertMsg.InsertMsg, ut.result, ut.partitionKeys) + } + if err != nil { + log.Warn("assign segmentID and repack insert data failed", zap.Error(err)) + ut.result.Status = merr.Status(err) + return nil, err + } + return msgs, nil +} + +func (it 
*upsertTaskByStreamingService) packDeleteMessage(ctx context.Context) ([]message.MutableMessage, error) { + tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy deleteExecute upsert %d", it.ID())) + collID := it.upsertMsg.DeleteMsg.CollectionID + it.upsertMsg.DeleteMsg.PrimaryKeys = it.oldIds + log := log.Ctx(ctx).With( + zap.Int64("collectionID", collID)) + // hash primary keys to channels + vChannels, err := it.chMgr.getVChannels(collID) + if err != nil { + log.Warn("get vChannels failed when deleteExecute", zap.Error(err)) + it.result.Status = merr.Status(err) + return nil, err + } + result, numRows, err := repackDeleteMsgByHash( + ctx, + it.upsertMsg.DeleteMsg.PrimaryKeys, + vChannels, + it.idAllocator, + it.BeginTs(), + it.upsertMsg.DeleteMsg.CollectionID, + it.upsertMsg.DeleteMsg.CollectionName, + it.upsertMsg.DeleteMsg.PartitionID, + it.upsertMsg.DeleteMsg.PartitionName, + ) + if err != nil { + return nil, err + } + + var msgs []message.MutableMessage + for hashKey, deleteMsg := range result { + vchannel := vChannels[hashKey] + msg, err := message.NewDeleteMessageBuilderV1(). + WithHeader(&message.DeleteMessageHeader{ + CollectionId: it.upsertMsg.DeleteMsg.CollectionID, + }). + WithBody(deleteMsg.DeleteRequest). + WithVChannel(vchannel). + BuildMutable() + if err != nil { + return nil, err + } + msgs = append(msgs, msg) + } + + log.Debug("Proxy Upsert deleteExecute done", + zap.Int64("collection_id", collID), + zap.Strings("virtual_channels", vChannels), + zap.Int64("taskID", it.ID()), + zap.Int64("numRows", numRows), + zap.Duration("prepare duration", tr.ElapseSpan())) + + return msgs, nil +} diff --git a/internal/querynodev2/delegator/delegator_data.go b/internal/querynodev2/delegator/delegator_data.go index 6c92fb066b..27c0e225d2 100644 --- a/internal/querynodev2/delegator/delegator_data.go +++ b/internal/querynodev2/delegator/delegator_data.go @@ -30,6 +30,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/internal/distributed/streaming" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/segcorepb" @@ -38,11 +39,15 @@ import ( "github.com/milvus-io/milvus/internal/querynodev2/pkoracle" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" mqcommon "github.com/milvus-io/milvus/pkg/mq/common" "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/message/adaptor" + "github.com/milvus-io/milvus/pkg/streaming/util/options" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/funcutil" @@ -750,14 +755,10 @@ func (sd *shardDelegator) loadStreamDelete(ctx context.Context, return nil } -func (sd *shardDelegator) readDeleteFromMsgstream(ctx context.Context, position *msgpb.MsgPosition, safeTs uint64, candidate *pkoracle.BloomFilterSet) (*storage.DeleteData, error) { - log := sd.getLogger(ctx).With( - zap.String("channel", position.ChannelName), - zap.Int64("segmentID", candidate.ID()), - ) +func (sd *shardDelegator) createStreamFromMsgStream(ctx context.Context, position 
*msgpb.MsgPosition) (ch <-chan *msgstream.MsgPack, closer func(), err error) { stream, err := sd.factory.NewTtMsgStream(ctx) if err != nil { - return nil, err + return nil, nil, err } - defer stream.Close() vchannelName := position.ChannelName @@ -771,15 +772,57 @@ log.Info("from dml check point load delete", zap.Any("position", position), zap.String("vChannel", vchannelName), zap.String("subName", subName), zap.Time("positionTs", ts)) err = stream.AsConsumer(context.TODO(), []string{pChannelName}, subName, mqcommon.SubscriptionPositionUnknown) if err != nil { - return nil, err + return nil, stream.Close, err } - ts = time.Now() err = stream.Seek(context.TODO(), []*msgpb.MsgPosition{position}, false) + if err != nil { + return nil, stream.Close, err + } + return stream.Chan(), stream.Close, nil +} + +func (sd *shardDelegator) createDeleteStreamFromStreamingService(ctx context.Context, position *msgpb.MsgPosition) (ch <-chan *msgstream.MsgPack, closer func(), err error) { + handler := adaptor.NewMsgPackAdaptorHandler() + s := streaming.WAL().Read(ctx, streaming.ReadOption{ + VChannel: position.GetChannelName(), + DeliverPolicy: options.DeliverPolicyStartFrom( + adaptor.MustGetMessageIDFromMQWrapperIDBytes("pulsar", position.GetMsgID()), + ), + DeliverFilters: []options.DeliverFilter{ + // only deliver message which timestamp >= position.Timestamp + options.DeliverFilterTimeTickGTE(position.GetTimestamp()), + // only delete message + options.DeliverFilterMessageType(message.MessageTypeDelete), + }, + MessageHandler: handler, + }) + return handler.Chan(), s.Close, nil +} + +func (sd *shardDelegator) readDeleteFromMsgstream(ctx context.Context, position *msgpb.MsgPosition, safeTs uint64, candidate *pkoracle.BloomFilterSet) (*storage.DeleteData, error) { + log := sd.getLogger(ctx).With( + zap.String("channel", position.ChannelName), + zap.Int64("segmentID", candidate.ID()), + ) + pChannelName := funcutil.ToPhysicalChannel(position.ChannelName) + + var ch <-chan *msgstream.MsgPack + var closer func() + var err error + if streamingutil.IsStreamingServiceEnabled() { + ch, closer, err = sd.createDeleteStreamFromStreamingService(ctx, position) + } else { + ch, closer, err = sd.createStreamFromMsgStream(ctx, position) + } + if closer != nil { + defer closer() + } if err != nil { return nil, err } + start := time.Now() result := &storage.DeleteData{} hasMore := true for hasMore { @@ -787,7 +830,7 @@ case <-ctx.Done(): log.Debug("read delta msg from seek position done", zap.Error(ctx.Err())) return nil, ctx.Err() - case msgPack, ok := <-stream.Chan(): + case msgPack, ok := <-ch: if !ok { err = fmt.Errorf("stream channel closed, pChannelName=%v, msgID=%v", pChannelName, position.GetMsgID()) log.Warn("fail to read delta msg", @@ -835,7 +878,7 @@ } } } - log.Info("successfully read delete from stream ", zap.Duration("time spent", time.Since(ts))) + log.Info("successfully read delete from stream ", zap.Duration("time spent", time.Since(start))) return result, nil } diff --git a/internal/rootcoord/create_collection_task.go b/internal/rootcoord/create_collection_task.go index 8246beb4d3..2c5df5d03d 100644 --- a/internal/rootcoord/create_collection_task.go +++ b/internal/rootcoord/create_collection_task.go @@ -30,12 +30,16 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/distributed/streaming" "github.com/milvus-io/milvus/internal/metastore/model" pb "github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/internal/util/proxyutil" + "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" ms "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/message/adaptor" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" @@ -400,13 +404,6 @@ func (t *createCollectionTask) Prepare(ctx context.Context) error { } func (t *createCollectionTask) genCreateCollectionMsg(ctx context.Context, ts uint64) *ms.MsgPack { - collectionID := t.collID - partitionIDs := t.partIDs - // error won't happen here. - marshaledSchema, _ := proto.Marshal(t.schema) - pChannels := t.channels.physicalChannels - vChannels := t.channels.virtualChannels - msgPack := ms.MsgPack{} msg := &ms.CreateCollectionMsg{ BaseMsg: ms.BaseMsg{ @@ -415,28 +412,78 @@ func (t *createCollectionTask) genCreateCollectionMsg(ctx context.Context, ts ui EndTimestamp: ts, HashValues: []uint32{0}, }, - CreateCollectionRequest: &msgpb.CreateCollectionRequest{ - Base: commonpbutil.NewMsgBase( - commonpbutil.WithMsgType(commonpb.MsgType_CreateCollection), - commonpbutil.WithTimeStamp(ts), - ), - CollectionID: collectionID, - PartitionIDs: partitionIDs, - Schema: marshaledSchema, - VirtualChannelNames: vChannels, - PhysicalChannelNames: pChannels, - }, + CreateCollectionRequest: t.genCreateCollectionRequest(), } msgPack.Msgs = append(msgPack.Msgs, msg) return &msgPack } +func (t *createCollectionTask) genCreateCollectionRequest() *msgpb.CreateCollectionRequest { + collectionID := t.collID + partitionIDs := t.partIDs + // error won't happen here. + marshaledSchema, _ := proto.Marshal(t.schema) + pChannels := t.channels.physicalChannels + vChannels := t.channels.virtualChannels + return &msgpb.CreateCollectionRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_CreateCollection), + commonpbutil.WithTimeStamp(t.ts), + ), + CollectionID: collectionID, + PartitionIDs: partitionIDs, + Schema: marshaledSchema, + VirtualChannelNames: vChannels, + PhysicalChannelNames: pChannels, + } +} + func (t *createCollectionTask) addChannelsAndGetStartPositions(ctx context.Context, ts uint64) (map[string][]byte, error) { t.core.chanTimeTick.addDmlChannels(t.channels.physicalChannels...) + if streamingutil.IsStreamingServiceEnabled() { + return t.broadcastCreateCollectionMsgIntoStreamingService(ctx, ts) + } msg := t.genCreateCollectionMsg(ctx, ts) return t.core.chanTimeTick.broadcastMarkDmlChannels(t.channels.physicalChannels, msg) } +func (t *createCollectionTask) broadcastCreateCollectionMsgIntoStreamingService(ctx context.Context, ts uint64) (map[string][]byte, error) { + req := t.genCreateCollectionRequest() + // dispatch the createCollectionMsg into all vchannel. + msgs := make([]message.MutableMessage, 0, len(req.VirtualChannelNames)) + for _, vchannel := range req.VirtualChannelNames { + msg, err := message.NewCreateCollectionMessageBuilderV1(). + WithVChannel(vchannel). 
+ WithHeader(&message.CreateCollectionMessageHeader{ + CollectionId: req.CollectionID, + PartitionIds: req.GetPartitionIDs(), + }). + WithBody(req). + BuildMutable() + if err != nil { + return nil, err + } + msgs = append(msgs, msg) + } + // send the createCollectionMsg into streaming service. + // ts is used as initial checkpoint at datacoord, + // it must be set as barrier time tick. + // The timetick of create message in wal must be greater than ts, to avoid data read loss at read side. + resps := streaming.WAL().AppendMessagesWithOption(ctx, streaming.AppendOption{ + BarrierTimeTick: ts, + }, msgs...) + if err := resps.UnwrapFirstError(); err != nil { + return nil, err + } + // make the old message stream serialized id. + startPositions := make(map[string][]byte) + for idx, resp := range resps.Responses { + // The key is pchannel here + startPositions[req.PhysicalChannelNames[idx]] = adaptor.MustGetMQWrapperIDFromMessage(resp.AppendResult.MessageID).Serialize() + } + return startPositions, nil +} + func (t *createCollectionTask) getCreateTs() (uint64, error) { replicateInfo := t.Req.GetBase().GetReplicateInfo() if !replicateInfo.GetIsReplicate() { diff --git a/internal/rootcoord/garbage_collector.go b/internal/rootcoord/garbage_collector.go index 0c36d59ee0..0fc9ea58b3 100644 --- a/internal/rootcoord/garbage_collector.go +++ b/internal/rootcoord/garbage_collector.go @@ -167,6 +167,10 @@ func (c *bgGarbageCollector) RemoveCreatingPartition(dbID int64, partition *mode } func (c *bgGarbageCollector) notifyCollectionGc(ctx context.Context, coll *model.Collection) (ddlTs Timestamp, err error) { + if streamingutil.IsStreamingServiceEnabled() { + return c.notifyCollectionGcByStreamingService(ctx, coll) + } + ts, err := c.s.tsoAllocator.GenerateTSO(1) if err != nil { return 0, err @@ -180,15 +184,7 @@ func (c *bgGarbageCollector) notifyCollectionGc(ctx context.Context, coll *model EndTimestamp: ts, HashValues: []uint32{0}, }, - DropCollectionRequest: &msgpb.DropCollectionRequest{ - Base: commonpbutil.NewMsgBase( - commonpbutil.WithMsgType(commonpb.MsgType_DropCollection), - commonpbutil.WithTimeStamp(ts), - commonpbutil.WithSourceID(c.s.session.ServerID), - ), - CollectionName: coll.Name, - CollectionID: coll.CollectionID, - }, + DropCollectionRequest: c.generateDropRequest(coll, ts), } msgPack.Msgs = append(msgPack.Msgs, msg) if err := c.s.chanTimeTick.broadcastDmlChannels(coll.PhysicalChannelNames, &msgPack); err != nil { @@ -198,6 +194,42 @@ func (c *bgGarbageCollector) notifyCollectionGc(ctx context.Context, coll *model return ts, nil } +func (c *bgGarbageCollector) generateDropRequest(coll *model.Collection, ts uint64) *msgpb.DropCollectionRequest { + return &msgpb.DropCollectionRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_DropCollection), + commonpbutil.WithTimeStamp(ts), + commonpbutil.WithSourceID(c.s.session.ServerID), + ), + CollectionName: coll.Name, + CollectionID: coll.CollectionID, + } +} + +func (c *bgGarbageCollector) notifyCollectionGcByStreamingService(ctx context.Context, coll *model.Collection) (uint64, error) { + req := c.generateDropRequest(coll, 0) // ts is given by streamingnode. + + msgs := make([]message.MutableMessage, 0, len(coll.VirtualChannelNames)) + for _, vchannel := range coll.VirtualChannelNames { + msg, err := message.NewDropCollectionMessageBuilderV1(). + WithVChannel(vchannel). + WithHeader(&message.DropCollectionMessageHeader{ + CollectionId: coll.CollectionID, + }). + WithBody(req). 
+ BuildMutable() + if err != nil { + return 0, err + } + msgs = append(msgs, msg) + } + resp := streaming.WAL().AppendMessages(ctx, msgs...) + if err := resp.UnwrapFirstError(); err != nil { + return 0, err + } + return resp.MaxTimeTick(), nil +} + func (c *bgGarbageCollector) notifyPartitionGc(ctx context.Context, pChannels []string, partition *model.Partition) (ddlTs Timestamp, err error) { ts, err := c.s.tsoAllocator.GenerateTSO(1) if err != nil { @@ -232,14 +264,10 @@ func (c *bgGarbageCollector) notifyPartitionGc(ctx context.Context, pChannels [] } func (c *bgGarbageCollector) notifyPartitionGcByStreamingService(ctx context.Context, vchannels []string, partition *model.Partition) (uint64, error) { - ts, err := c.s.tsoAllocator.GenerateTSO(1) - if err != nil { - return 0, err - } req := &msgpb.DropPartitionRequest{ Base: commonpbutil.NewMsgBase( commonpbutil.WithMsgType(commonpb.MsgType_DropPartition), - commonpbutil.WithTimeStamp(ts), + commonpbutil.WithTimeStamp(0), // Timetick is given by streamingnode. commonpbutil.WithSourceID(c.s.session.ServerID), ), PartitionName: partition.PartitionName, @@ -263,12 +291,11 @@ func (c *bgGarbageCollector) notifyPartitionGcByStreamingService(ctx context.Con msgs = append(msgs, msg) } // Ts is used as barrier time tick to ensure the message's time tick are given after the barrier time tick. - if err := streaming.WAL().Utility().AppendMessagesWithOption(ctx, streaming.AppendOption{ - BarrierTimeTick: ts, - }, msgs...).UnwrapFirstError(); err != nil { + resp := streaming.WAL().AppendMessages(ctx, msgs...) + if err := resp.UnwrapFirstError(); err != nil { return 0, err } - return ts, nil + return resp.MaxTimeTick(), nil } func (c *bgGarbageCollector) GcCollectionData(ctx context.Context, coll *model.Collection) (ddlTs Timestamp, err error) { diff --git a/internal/rootcoord/garbage_collector_test.go b/internal/rootcoord/garbage_collector_test.go index 43b9d34741..f49da0e00c 100644 --- a/internal/rootcoord/garbage_collector_test.go +++ b/internal/rootcoord/garbage_collector_test.go @@ -547,15 +547,11 @@ func TestGcPartitionData(t *testing.T) { defer streamingutil.UnsetStreamingServiceEnabled() wal := mock_streaming.NewMockWALAccesser(t) - u := mock_streaming.NewMockUtility(t) - u.EXPECT().AppendMessagesWithOption(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(streaming.AppendResponses{}) - wal.EXPECT().Utility().Return(u) + wal.EXPECT().AppendMessages(mock.Anything, mock.Anything, mock.Anything).Return(streaming.AppendResponses{}) streaming.SetWALForTest(wal) tsoAllocator := mocktso.NewAllocator(t) - tsoAllocator.EXPECT().GenerateTSO(mock.Anything).Return(1000, nil) - - core := newTestCore(withTsoAllocator(tsoAllocator)) + core := newTestCore() gc := newBgGarbageCollector(core) core.ddlTsLockManager = newDdlTsLockManager(tsoAllocator) diff --git a/internal/rootcoord/root_coord.go b/internal/rootcoord/root_coord.go index 694baaea21..4661dc8179 100644 --- a/internal/rootcoord/root_coord.go +++ b/internal/rootcoord/root_coord.go @@ -50,6 +50,7 @@ import ( "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/proxyutil" "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/internal/util/streamingutil" tsoutil2 "github.com/milvus-io/milvus/internal/util/tsoutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/kv" @@ -716,10 +717,13 @@ func (c *Core) startInternal() error { } func (c *Core) startServerLoop() { - c.wg.Add(3) + 
c.wg.Add(2) go c.startTimeTickLoop() go c.tsLoop() - go c.chanTimeTick.startWatch(&c.wg) + if !streamingutil.IsStreamingServiceEnabled() { + c.wg.Add(1) + go c.chanTimeTick.startWatch(&c.wg) + } } // Start starts RootCoord. diff --git a/internal/rootcoord/step.go b/internal/rootcoord/step.go index 42c47dbcaa..9b18df3a3c 100644 --- a/internal/rootcoord/step.go +++ b/internal/rootcoord/step.go @@ -412,7 +412,7 @@ func (s *broadcastCreatePartitionMsgStep) Execute(ctx context.Context) ([]nested } msgs = append(msgs, msg) } - if err := streaming.WAL().Utility().AppendMessagesWithOption(ctx, streaming.AppendOption{ + if err := streaming.WAL().AppendMessagesWithOption(ctx, streaming.AppendOption{ BarrierTimeTick: s.ts, }, msgs...).UnwrapFirstError(); err != nil { return nil, err diff --git a/internal/rootcoord/step_test.go b/internal/rootcoord/step_test.go index 958a946b1f..2e9e7de6a6 100644 --- a/internal/rootcoord/step_test.go +++ b/internal/rootcoord/step_test.go @@ -123,9 +123,7 @@ func TestSkip(t *testing.T) { func TestBroadcastCreatePartitionMsgStep(t *testing.T) { wal := mock_streaming.NewMockWALAccesser(t) - u := mock_streaming.NewMockUtility(t) - u.EXPECT().AppendMessagesWithOption(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(streaming.AppendResponses{}) - wal.EXPECT().Utility().Return(u) + wal.EXPECT().AppendMessagesWithOption(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(streaming.AppendResponses{}) streaming.SetWALForTest(wal) step := &broadcastCreatePartitionMsgStep{ diff --git a/internal/rootcoord/timeticksync.go b/internal/rootcoord/timeticksync.go index fee0c59ca1..22eed18acb 100644 --- a/internal/rootcoord/timeticksync.go +++ b/internal/rootcoord/timeticksync.go @@ -28,6 +28,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgstream" @@ -162,7 +163,7 @@ func (t *timetickSync) sendToChannel() bool { } } - if len(idleSessionList) > 0 { + if len(idleSessionList) > 0 && !streamingutil.IsStreamingServiceEnabled() { // give warning every 2 second if not get ttMsg from source sessions if maxCnt%10 == 0 { log.Warn("session idle for long time", zap.Any("idle list", idleSessionList), @@ -319,6 +320,9 @@ func (t *timetickSync) startWatch(wg *sync.WaitGroup) { // SendTimeTickToChannel send each channel's min timetick to msg stream func (t *timetickSync) sendTimeTickToChannel(chanNames []string, ts typeutil.Timestamp) error { + if streamingutil.IsStreamingServiceEnabled() { + return nil + } func() { sub := tsoutil.SubByNow(ts) for _, chanName := range chanNames { diff --git a/internal/streamingnode/server/builder.go b/internal/streamingnode/server/builder.go index ee4b262f94..ab95ee4ad9 100644 --- a/internal/streamingnode/server/builder.go +++ b/internal/streamingnode/server/builder.go @@ -38,6 +38,12 @@ func (b *ServerBuilder) WithETCD(e *clientv3.Client) *ServerBuilder { return b } +// WithChunkManager sets chunk manager to the server builder. +func (b *ServerBuilder) WithChunkManager(cm storage.ChunkManager) *ServerBuilder { + b.chunkManager = cm + return b +} + // WithGRPCServer sets grpc server to the server builder. 
func (b *ServerBuilder) WithGRPCServer(svr *grpc.Server) *ServerBuilder { b.grpcServer = svr @@ -68,27 +74,21 @@ func (b *ServerBuilder) WithMetaKV(kv kv.MetaKv) *ServerBuilder { return b } -// WithChunkManager sets chunk manager to the server builder. -func (b *ServerBuilder) WithChunkManager(chunkManager storage.ChunkManager) *ServerBuilder { - b.chunkManager = chunkManager - return b -} - // Build builds a streaming node server. -func (s *ServerBuilder) Build() *Server { +func (b *ServerBuilder) Build() *Server { resource.Apply( - resource.OptETCD(s.etcdClient), - resource.OptRootCoordClient(s.rc), - resource.OptDataCoordClient(s.dc), - resource.OptStreamingNodeCatalog(streamingnode.NewCataLog(s.kv)), + resource.OptETCD(b.etcdClient), + resource.OptRootCoordClient(b.rc), + resource.OptDataCoordClient(b.dc), + resource.OptStreamingNodeCatalog(streamingnode.NewCataLog(b.kv)), ) resource.Apply( - resource.OptFlusher(flusherimpl.NewFlusher(s.chunkManager)), + resource.OptFlusher(flusherimpl.NewFlusher(b.chunkManager)), ) resource.Done() return &Server{ - session: s.session, - grpcServer: s.grpcServer, + session: b.session, + grpcServer: b.grpcServer, componentStateService: componentutil.NewComponentStateService(typeutil.StreamingNodeRole), } } diff --git a/internal/streamingnode/server/flusher/flusherimpl/channel_task.go b/internal/streamingnode/server/flusher/flusherimpl/channel_task.go new file mode 100644 index 0000000000..3de3f12b23 --- /dev/null +++ b/internal/streamingnode/server/flusher/flusherimpl/channel_task.go @@ -0,0 +1,139 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package flusherimpl + +import ( + "context" + "sync" + "time" + + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/flushcommon/pipeline" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/streamingnode/server/resource" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal" + adaptor2 "github.com/milvus-io/milvus/internal/streamingnode/server/wal/adaptor" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/streaming/util/message/adaptor" + "github.com/milvus-io/milvus/pkg/streaming/util/options" + "github.com/milvus-io/milvus/pkg/util/merr" +) + +type TaskState int + +const ( + Pending TaskState = iota + Cancel + Done +) + +type ChannelTask interface { + Run() error + Cancel() +} + +type channelTask struct { + mu sync.Mutex + state TaskState + f *flusherImpl + vchannel string + wal wal.WAL +} + +func NewChannelTask(f *flusherImpl, vchannel string, wal wal.WAL) ChannelTask { + return &channelTask{ + state: Pending, + f: f, + vchannel: vchannel, + wal: wal, + } +} + +func (c *channelTask) Run() error { + c.mu.Lock() + defer c.mu.Unlock() + if c.state == Cancel { + return nil + } + if c.f.fgMgr.HasFlowgraph(c.vchannel) { + return nil + } + log.Info("start to build pipeline", zap.String("vchannel", c.vchannel)) + + // Get recovery info from datacoord. + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + resp, err := resource.Resource().DataCoordClient(). + GetChannelRecoveryInfo(ctx, &datapb.GetChannelRecoveryInfoRequest{Vchannel: c.vchannel}) + if err = merr.CheckRPCCall(resp, err); err != nil { + return err + } + + // Convert common.MessageID to message.messageID. + messageID := adaptor.MustGetMessageIDFromMQWrapperIDBytes(c.wal.WALName(), resp.GetInfo().GetSeekPosition().GetMsgID()) + + // Create scanner. + policy := options.DeliverPolicyStartFrom(messageID) + handler := adaptor2.NewMsgPackAdaptorHandler() + ro := wal.ReadOption{ + DeliverPolicy: policy, + MessageFilter: []options.DeliverFilter{ + options.DeliverFilterVChannel(c.vchannel), + }, + MesasgeHandler: handler, + } + scanner, err := c.wal.Read(ctx, ro) + if err != nil { + return err + } + + // Build and add pipeline. 
+ ds, err := pipeline.NewStreamingNodeDataSyncService(ctx, c.f.pipelineParams, + &datapb.ChannelWatchInfo{Vchan: resp.GetInfo(), Schema: resp.GetSchema()}, handler.Chan()) + if err != nil { + return err + } + ds.Start() + c.f.fgMgr.AddFlowgraph(ds) + c.f.scanners.Insert(c.vchannel, scanner) + c.state = Done + + log.Info("build pipeline done", zap.String("vchannel", c.vchannel)) + return nil +} + +func (c *channelTask) Cancel() { + c.mu.Lock() + defer c.mu.Unlock() + switch c.state { + case Pending: + c.state = Cancel + case Cancel: + return + case Done: + if scanner, ok := c.f.scanners.GetAndRemove(c.vchannel); ok { + err := scanner.Close() + if err != nil { + log.Warn("scanner error", zap.String("vchannel", c.vchannel), zap.Error(err)) + } + } + c.f.fgMgr.RemoveFlowgraph(c.vchannel) + c.f.wbMgr.RemoveChannel(c.vchannel) + log.Info("flusher unregister vchannel done", zap.String("vchannel", c.vchannel)) + } +} diff --git a/internal/streamingnode/server/flusher/flusherimpl/flusher_impl.go b/internal/streamingnode/server/flusher/flusherimpl/flusher_impl.go index 432849273d..b22feeea83 100644 --- a/internal/streamingnode/server/flusher/flusherimpl/flusher_impl.go +++ b/internal/streamingnode/server/flusher/flusherimpl/flusher_impl.go @@ -23,40 +23,39 @@ import ( "go.uber.org/zap" + "github.com/milvus-io/milvus/internal/flushcommon/broker" "github.com/milvus-io/milvus/internal/flushcommon/pipeline" "github.com/milvus-io/milvus/internal/flushcommon/syncmgr" "github.com/milvus-io/milvus/internal/flushcommon/util" "github.com/milvus-io/milvus/internal/flushcommon/writebuffer" - "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/streamingnode/server/flusher" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" - adaptor2 "github.com/milvus-io/milvus/internal/streamingnode/server/wal/adaptor" "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/streaming/util/message/adaptor" - "github.com/milvus-io/milvus/pkg/streaming/util/options" + "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/lifetime" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" ) -var tickDuration = 3 * time.Second - var _ flusher.Flusher = (*flusherImpl)(nil) type flusherImpl struct { + broker broker.Broker fgMgr pipeline.FlowgraphManager syncMgr syncmgr.SyncManager wbMgr writebuffer.BufferManager cpUpdater *util.ChannelCheckpointUpdater - tasks *typeutil.ConcurrentMap[string, wal.WAL] // unwatched vchannels + tasks *typeutil.ConcurrentMap[string, ChannelTask] scanners *typeutil.ConcurrentMap[string, wal.Scanner] // watched scanners - stopOnce sync.Once - stopChan chan struct{} + notifyCh chan struct{} + stopChan lifetime.SafeChan + stopWg sync.WaitGroup pipelineParams *util.PipelineParams } @@ -68,14 +67,15 @@ func NewFlusher(chunkManager storage.ChunkManager) flusher.Flusher { func newFlusherWithParam(params *util.PipelineParams) flusher.Flusher { fgMgr := pipeline.NewFlowgraphManager() return &flusherImpl{ + broker: params.Broker, fgMgr: fgMgr, syncMgr: params.SyncMgr, wbMgr: params.WriteBufferManager, cpUpdater: params.CheckpointUpdater, - tasks: typeutil.NewConcurrentMap[string, wal.WAL](), + tasks: typeutil.NewConcurrentMap[string, ChannelTask](), scanners: 
typeutil.NewConcurrentMap[string, wal.Scanner](), - stopOnce: sync.Once{}, - stopChan: make(chan struct{}), + notifyCh: make(chan struct{}, 1), + stopChan: lifetime.NewSafeChan(), pipelineParams: params, } } @@ -90,26 +90,39 @@ func (f *flusherImpl) RegisterPChannel(pchannel string, wal wal.WAL) error { return err } for _, collectionInfo := range resp.GetCollections() { - f.tasks.Insert(collectionInfo.GetVchannel(), wal) + f.RegisterVChannel(collectionInfo.GetVchannel(), wal) } return nil } +func (f *flusherImpl) RegisterVChannel(vchannel string, wal wal.WAL) { + if f.scanners.Contain(vchannel) { + return + } + f.tasks.GetOrInsert(vchannel, NewChannelTask(f, vchannel, wal)) + f.notify() + log.Info("flusher register vchannel done", zap.String("vchannel", vchannel)) +} + func (f *flusherImpl) UnregisterPChannel(pchannel string) { - f.scanners.Range(func(vchannel string, scanner wal.Scanner) bool { - if funcutil.ToPhysicalChannel(vchannel) != pchannel { - return true + f.tasks.Range(func(vchannel string, task ChannelTask) bool { + if funcutil.ToPhysicalChannel(vchannel) == pchannel { + f.UnregisterVChannel(vchannel) + } + return true + }) + f.scanners.Range(func(vchannel string, scanner wal.Scanner) bool { + if funcutil.ToPhysicalChannel(vchannel) == pchannel { + f.UnregisterVChannel(vchannel) } - f.UnregisterVChannel(vchannel) return true }) } -func (f *flusherImpl) RegisterVChannel(vchannel string, wal wal.WAL) { - f.tasks.Insert(vchannel, wal) -} - func (f *flusherImpl) UnregisterVChannel(vchannel string) { + if task, ok := f.tasks.Get(vchannel); ok { + task.Cancel() + } if scanner, ok := f.scanners.GetAndRemove(vchannel); ok { err := scanner.Close() if err != nil { @@ -118,96 +131,61 @@ func (f *flusherImpl) UnregisterVChannel(vchannel string) { } f.fgMgr.RemoveFlowgraph(vchannel) f.wbMgr.RemoveChannel(vchannel) + log.Info("flusher unregister vchannel done", zap.String("vchannel", vchannel)) +} + +func (f *flusherImpl) notify() { + select { + case f.notifyCh <- struct{}{}: + default: + } } func (f *flusherImpl) Start() { + f.stopWg.Add(1) f.wbMgr.Start() go f.cpUpdater.Start() go func() { - ticker := time.NewTicker(tickDuration) - defer ticker.Stop() + defer f.stopWg.Done() for { select { - case <-f.stopChan: - log.Info("flusher stopped") + case <-f.stopChan.CloseCh(): + log.Info("flusher exited") return - case <-ticker.C: - f.tasks.Range(func(vchannel string, wal wal.WAL) bool { - err := f.buildPipeline(vchannel, wal) - if err != nil { - log.Warn("build pipeline failed", zap.String("vchannel", vchannel), zap.Error(err)) - return true - } - log.Info("build pipeline done", zap.String("vchannel", vchannel)) - f.tasks.Remove(vchannel) + case <-f.notifyCh: + futures := make([]*conc.Future[any], 0) + f.tasks.Range(func(vchannel string, task ChannelTask) bool { + future := GetExecPool().Submit(func() (any, error) { + err := task.Run() + if err != nil { + log.Warn("build pipeline failed", zap.String("vchannel", vchannel), zap.Error(err)) + // Notify to trigger retry. + f.notify() + return nil, err + } + f.tasks.Remove(vchannel) + return nil, nil + }) + futures = append(futures, future) return true }) + _ = conc.AwaitAll(futures...) 
} } }() } func (f *flusherImpl) Stop() { - f.stopOnce.Do(func() { - close(f.stopChan) - f.scanners.Range(func(vchannel string, scanner wal.Scanner) bool { - err := scanner.Close() - if err != nil { - log.Warn("scanner error", zap.String("vchannel", vchannel), zap.Error(err)) - } - return true - }) - f.fgMgr.ClearFlowgraphs() - f.wbMgr.Stop() - f.cpUpdater.Close() + f.stopChan.Close() + f.stopWg.Wait() + f.scanners.Range(func(vchannel string, scanner wal.Scanner) bool { + err := scanner.Close() + if err != nil { + log.Warn("scanner error", zap.String("vchannel", vchannel), zap.Error(err)) + } + return true }) -} - -func (f *flusherImpl) buildPipeline(vchannel string, w wal.WAL) error { - if f.fgMgr.HasFlowgraph(vchannel) { - return nil - } - log.Info("start to build pipeline", zap.String("vchannel", vchannel)) - - // Get recovery info from datacoord. - ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) - defer cancel() - resp, err := resource.Resource().DataCoordClient(). - GetChannelRecoveryInfo(ctx, &datapb.GetChannelRecoveryInfoRequest{Vchannel: vchannel}) - if err = merr.CheckRPCCall(resp, err); err != nil { - return err - } - - // Convert common.MessageID to message.messageID. - mqWrapperID, err := adaptor.DeserializeToMQWrapperID(resp.GetInfo().GetSeekPosition().GetMsgID(), w.WALName()) - if err != nil { - return err - } - messageID := adaptor.MustGetMessageIDFromMQWrapperID(mqWrapperID) - - // Create scanner. - policy := options.DeliverPolicyStartFrom(messageID) - handler := adaptor2.NewMsgPackAdaptorHandler() - ro := wal.ReadOption{ - DeliverPolicy: policy, - MessageFilter: []options.DeliverFilter{ - options.DeliverFilterVChannel(vchannel), - }, - MesasgeHandler: handler, - } - scanner, err := w.Read(ctx, ro) - if err != nil { - return err - } - - // Build and add pipeline. 
- ds, err := pipeline.NewStreamingNodeDataSyncService(ctx, f.pipelineParams, - &datapb.ChannelWatchInfo{Vchan: resp.GetInfo(), Schema: resp.GetSchema()}, handler.Chan()) - if err != nil { - return err - } - ds.Start() - f.fgMgr.AddFlowgraph(ds) - f.scanners.Insert(vchannel, scanner) - return nil + f.fgMgr.ClearFlowgraphs() + f.wbMgr.Stop() + f.cpUpdater.Close() } diff --git a/internal/streamingnode/server/flusher/flusherimpl/flusher_impl_test.go b/internal/streamingnode/server/flusher/flusherimpl/flusher_impl_test.go index c4d6b18715..f5fb24ba8b 100644 --- a/internal/streamingnode/server/flusher/flusherimpl/flusher_impl_test.go +++ b/internal/streamingnode/server/flusher/flusherimpl/flusher_impl_test.go @@ -40,6 +40,7 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/wal" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type FlusherSuite struct { @@ -58,7 +59,6 @@ type FlusherSuite struct { func (s *FlusherSuite) SetupSuite() { paramtable.Init() - tickDuration = 10 * time.Millisecond s.pchannel = "by-dev-rootcoord-dml_0" s.vchannels = []string{ @@ -106,26 +106,27 @@ func (s *FlusherSuite) SetupSuite() { } func (s *FlusherSuite) SetupTest() { - handlers := make([]wal.MessageHandler, 0, len(s.vchannels)) + handlers := typeutil.NewConcurrentSet[wal.MessageHandler]() scanner := mock_wal.NewMockScanner(s.T()) w := mock_wal.NewMockWAL(s.T()) - w.EXPECT().WALName().Return("rocksmq") + w.EXPECT().WALName().Return("rocksmq").Maybe() w.EXPECT().Read(mock.Anything, mock.Anything).RunAndReturn( func(ctx context.Context, option wal.ReadOption) (wal.Scanner, error) { - handlers = append(handlers, option.MesasgeHandler) + handlers.Insert(option.MesasgeHandler) return scanner, nil - }) + }).Maybe() once := sync.Once{} scanner.EXPECT().Close().RunAndReturn(func() error { once.Do(func() { - for _, handler := range handlers { - handler.Close() - } + handlers.Range(func(h wal.MessageHandler) bool { + h.Close() + return true + }) }) return nil - }) + }).Maybe() s.wal = w m := mocks.NewChunkManager(s.T()) @@ -164,6 +165,7 @@ func (s *FlusherSuite) TestFlusher_RegisterPChannel() { s.flusher.UnregisterPChannel(s.pchannel) s.Equal(0, s.flusher.(*flusherImpl).fgMgr.GetFlowgraphCount()) s.Equal(0, s.flusher.(*flusherImpl).scanners.Len()) + s.Equal(0, s.flusher.(*flusherImpl).tasks.Len()) } func (s *FlusherSuite) TestFlusher_RegisterVChannel() { @@ -181,6 +183,35 @@ func (s *FlusherSuite) TestFlusher_RegisterVChannel() { } s.Equal(0, s.flusher.(*flusherImpl).fgMgr.GetFlowgraphCount()) s.Equal(0, s.flusher.(*flusherImpl).scanners.Len()) + s.Equal(0, s.flusher.(*flusherImpl).tasks.Len()) +} + +func (s *FlusherSuite) TestFlusher_Concurrency() { + wg := &sync.WaitGroup{} + for i := 0; i < 10; i++ { + for _, vchannel := range s.vchannels { + wg.Add(1) + go func(vchannel string) { + s.flusher.RegisterVChannel(vchannel, s.wal) + wg.Done() + }(vchannel) + } + for _, vchannel := range s.vchannels { + wg.Add(1) + go func(vchannel string) { + s.flusher.UnregisterVChannel(vchannel) + wg.Done() + }(vchannel) + } + } + wg.Wait() + + for _, vchannel := range s.vchannels { + s.flusher.UnregisterVChannel(vchannel) + } + s.Equal(0, s.flusher.(*flusherImpl).fgMgr.GetFlowgraphCount()) + s.Equal(0, s.flusher.(*flusherImpl).scanners.Len()) + s.Equal(0, s.flusher.(*flusherImpl).tasks.Len()) } func TestFlusherSuite(t *testing.T) { diff --git a/internal/streamingnode/server/flusher/flusherimpl/flushmsg_handler_impl.go 
b/internal/streamingnode/server/flusher/flusherimpl/flushmsg_handler_impl.go index ae74cacd13..7993495acb 100644 --- a/internal/streamingnode/server/flusher/flusherimpl/flushmsg_handler_impl.go +++ b/internal/streamingnode/server/flusher/flusherimpl/flushmsg_handler_impl.go @@ -19,18 +19,36 @@ package flusherimpl import ( "context" - "go.uber.org/zap" + "github.com/cockroachdb/errors" "github.com/milvus-io/milvus/internal/flushcommon/writebuffer" - "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/streaming/util/message" ) -// TODO: func(vchannel string, msg FlushMsg) -func flushMsgHandlerImpl(wbMgr writebuffer.BufferManager) func(vchannel string, segmentIDs []int64) { - return func(vchannel string, segmentIDs []int64) { - err := wbMgr.SealSegments(context.Background(), vchannel, segmentIDs) - if err != nil { - log.Warn("failed to seal segments", zap.Error(err)) - } +func newFlushMsgHandler(wbMgr writebuffer.BufferManager) *flushMsgHandlerImpl { + return &flushMsgHandlerImpl{ + wbMgr: wbMgr, } } + +type flushMsgHandlerImpl struct { + wbMgr writebuffer.BufferManager +} + +func (impl *flushMsgHandlerImpl) HandleFlush(vchannel string, flushMsg message.ImmutableFlushMessageV2) error { + body, err := flushMsg.Body() + if err != nil { + return errors.Wrap(err, "failed to get flush message body") + } + if err := impl.wbMgr.SealSegments(context.Background(), vchannel, body.GetSegmentId()); err != nil { + return errors.Wrap(err, "failed to seal segments") + } + return nil +} + +func (impl *flushMsgHandlerImpl) HandleManualFlush(vchannel string, flushMsg message.ImmutableManualFlushMessageV2) error { + if err := impl.wbMgr.FlushChannel(context.Background(), vchannel, flushMsg.Header().GetFlushTs()); err != nil { + return errors.Wrap(err, "failed to flush channel") + } + return nil +} diff --git a/internal/streamingnode/server/flusher/flusherimpl/flushmsg_handler_impl_test.go b/internal/streamingnode/server/flusher/flusherimpl/flushmsg_handler_impl_test.go index cce193ad28..45e62b50ba 100644 --- a/internal/streamingnode/server/flusher/flusherimpl/flushmsg_handler_impl_test.go +++ b/internal/streamingnode/server/flusher/flusherimpl/flushmsg_handler_impl_test.go @@ -20,23 +20,76 @@ import ( "testing" "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/milvus-io/milvus/internal/flushcommon/writebuffer" + "github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_message" + "github.com/milvus-io/milvus/pkg/streaming/util/message" ) -func TestFlushMsgHandler(t *testing.T) { +func TestFlushMsgHandler_HandleFlush(t *testing.T) { + vchannel := "ch-0" + // test failed wbMgr := writebuffer.NewMockBufferManager(t) wbMgr.EXPECT().SealSegments(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mock err")) - fn := flushMsgHandlerImpl(wbMgr) - fn("ch-0", []int64{1, 2, 3}) + msg, err := message.NewFlushMessageBuilderV2(). + WithVChannel(vchannel). + WithHeader(&message.FlushMessageHeader{}). + WithBody(&message.FlushMessageBody{ + CollectionId: 0, + SegmentId: []int64{1, 2, 3}, + }). 
+ BuildMutable() + assert.NoError(t, err) + + handler := newFlushMsgHandler(wbMgr) + msgID := mock_message.NewMockMessageID(t) + im, err := message.AsImmutableFlushMessageV2(msg.IntoImmutableMessage(msgID)) + assert.NoError(t, err) + err = handler.HandleFlush(vchannel, im) + assert.Error(t, err) // test normal wbMgr = writebuffer.NewMockBufferManager(t) wbMgr.EXPECT().SealSegments(mock.Anything, mock.Anything, mock.Anything).Return(nil) - fn = flushMsgHandlerImpl(wbMgr) - fn("ch-0", []int64{1, 2, 3}) + handler = newFlushMsgHandler(wbMgr) + err = handler.HandleFlush(vchannel, im) + assert.NoError(t, err) +} + +func TestFlushMsgHandler_HandleManualFlush(t *testing.T) { + vchannel := "ch-0" + + // test failed + wbMgr := writebuffer.NewMockBufferManager(t) + wbMgr.EXPECT().FlushChannel(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mock err")) + + msg, err := message.NewManualFlushMessageBuilderV2(). + WithVChannel(vchannel). + WithHeader(&message.ManualFlushMessageHeader{ + CollectionId: 0, + FlushTs: 1000, + }). + WithBody(&message.ManualFlushMessageBody{}). + BuildMutable() + assert.NoError(t, err) + + handler := newFlushMsgHandler(wbMgr) + msgID := mock_message.NewMockMessageID(t) + im, err := message.AsImmutableManualFlushMessageV2(msg.IntoImmutableMessage(msgID)) + assert.NoError(t, err) + err = handler.HandleManualFlush(vchannel, im) + assert.Error(t, err) + + // test normal + wbMgr = writebuffer.NewMockBufferManager(t) + wbMgr.EXPECT().FlushChannel(mock.Anything, mock.Anything, mock.Anything).Return(nil) + + handler = newFlushMsgHandler(wbMgr) + err = handler.HandleManualFlush(vchannel, im) + assert.NoError(t, err) } diff --git a/internal/streamingnode/server/flusher/flusherimpl/pipeline_params.go b/internal/streamingnode/server/flusher/flusherimpl/pipeline_params.go index 086f924efc..d4c327083d 100644 --- a/internal/streamingnode/server/flusher/flusherimpl/pipeline_params.go +++ b/internal/streamingnode/server/flusher/flusherimpl/pipeline_params.go @@ -46,6 +46,6 @@ func getPipelineParams(chunkManager storage.ChunkManager) *util.PipelineParams { WriteBufferManager: wbMgr, CheckpointUpdater: cpUpdater, Allocator: idalloc.NewMAllocator(rsc.IDAllocator()), - FlushMsgHandler: flushMsgHandlerImpl(wbMgr), + FlushMsgHandler: newFlushMsgHandler(wbMgr), } } diff --git a/internal/streamingnode/server/flusher/flusherimpl/pool.go b/internal/streamingnode/server/flusher/flusherimpl/pool.go new file mode 100644 index 0000000000..fcf527da2d --- /dev/null +++ b/internal/streamingnode/server/flusher/flusherimpl/pool.go @@ -0,0 +1,40 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
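The new pool.go below gives the flusher a process-wide conc.Pool[any] of 128 workers, created lazily under a sync.Once, and the reworked Start() loop pairs it with a capacity-1 notifyCh so that any number of RegisterVChannel calls collapse into at most one pending wake-up. A rough, self-contained sketch of those two patterns, with a plain buffered channel standing in for conc.Pool and illustrative names throughout:

package main

import (
	"fmt"
	"sync"
	"time"
)

var (
	pool     chan func() // stand-in for the shared conc.Pool[any]
	poolOnce sync.Once
)

// getPool lazily builds one process-wide pool, like GetExecPool in pool.go.
func getPool() chan func() {
	poolOnce.Do(func() {
		pool = make(chan func(), 128)
		for i := 0; i < 4; i++ { // fixed number of workers
			go func() {
				for job := range pool {
					job()
				}
			}()
		}
	})
	return pool
}

type worker struct {
	notifyCh chan struct{}
}

// notify never blocks: if a signal is already pending, the new one is dropped,
// which is safe because each wake-up re-reads the full task set.
func (w *worker) notify() {
	select {
	case w.notifyCh <- struct{}{}:
	default:
	}
}

// loop drains one signal at a time and fans the tasks out to the pool;
// a failed task re-notifies, turning the loop into a retry mechanism.
func (w *worker) loop(tasks []func() error) {
	for range w.notifyCh {
		for _, t := range tasks {
			task := t
			getPool() <- func() {
				if err := task(); err != nil {
					w.notify()
				}
			}
		}
	}
}

func main() {
	w := &worker{notifyCh: make(chan struct{}, 1)}
	go w.loop([]func() error{func() error { fmt.Println("task ran"); return nil }})
	w.notify()
	w.notify() // coalesced with the first signal
	time.Sleep(100 * time.Millisecond)
}

Dropping a notification when one is already queued works here because every wake-up re-scans the whole task map and failed tasks re-notify; that event-driven retry is what replaces the old 3-second tickDuration polling loop removed from flusher_impl.go.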
+ +package flusherimpl + +import ( + "sync" + + "github.com/milvus-io/milvus/pkg/util/conc" +) + +var ( + execPool *conc.Pool[any] + execPoolInitOnce sync.Once +) + +func initExecPool() { + execPool = conc.NewPool[any]( + 128, + conc.WithPreAlloc(true), + ) +} + +func GetExecPool() *conc.Pool[any] { + execPoolInitOnce.Do(initExecPool) + return execPool +} diff --git a/internal/streamingnode/server/flusher/flushmsg_handler.go b/internal/streamingnode/server/flusher/flushmsg_handler.go index 00a9b1f42f..5df71680e0 100644 --- a/internal/streamingnode/server/flusher/flushmsg_handler.go +++ b/internal/streamingnode/server/flusher/flushmsg_handler.go @@ -16,6 +16,10 @@ package flusher -// TODO: type FlushMsgHandler = func(vchannel string, msg FlushMsg) +import "github.com/milvus-io/milvus/pkg/streaming/util/message" -type FlushMsgHandler = func(vchannel string, segmentIDs []int64) +type FlushMsgHandler interface { + HandleFlush(vchannel string, flushMsg message.ImmutableFlushMessageV2) error + + HandleManualFlush(vchannel string, flushMsg message.ImmutableManualFlushMessageV2) error +} diff --git a/internal/streamingnode/server/resource/resource.go b/internal/streamingnode/server/resource/resource.go index e86d490a22..23ff631605 100644 --- a/internal/streamingnode/server/resource/resource.go +++ b/internal/streamingnode/server/resource/resource.go @@ -9,7 +9,6 @@ import ( "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/streamingnode/server/flusher" "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" - sinspector "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/inspector" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" tinspector "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/inspector" "github.com/milvus-io/milvus/internal/types" @@ -75,14 +74,12 @@ func Apply(opts ...optResourceInit) { // Done finish all initialization of resources. func Done() { r.segmentAssignStatsManager = stats.NewStatsManager() - r.segmentSealedInspector = sinspector.NewSealedInspector(r.segmentAssignStatsManager.SealNotifier()) r.timeTickInspector = tinspector.NewTimeTickSyncInspector() assertNotNil(r.TSOAllocator()) assertNotNil(r.RootCoordClient()) assertNotNil(r.DataCoordClient()) assertNotNil(r.StreamingNodeCatalog()) assertNotNil(r.SegmentAssignStatsManager()) - assertNotNil(r.SegmentSealedInspector()) assertNotNil(r.TimeTickInspector()) } @@ -103,7 +100,6 @@ type resourceImpl struct { dataCoordClient types.DataCoordClient streamingNodeCatalog metastore.StreamingNodeCataLog segmentAssignStatsManager *stats.StatsManager - segmentSealedInspector sinspector.SealOperationInspector timeTickInspector tinspector.TimeTickSyncInspector } @@ -152,11 +148,6 @@ func (r *resourceImpl) SegmentAssignStatsManager() *stats.StatsManager { return r.segmentAssignStatsManager } -// SegmentSealedInspector returns the segment sealed inspector. 
-func (r *resourceImpl) SegmentSealedInspector() sinspector.SealOperationInspector { - return r.segmentSealedInspector -} - func (r *resourceImpl) TimeTickInspector() tinspector.TimeTickSyncInspector { return r.timeTickInspector } diff --git a/internal/streamingnode/server/resource/test_utility.go b/internal/streamingnode/server/resource/test_utility.go index da3b220404..bad9e0f4bf 100644 --- a/internal/streamingnode/server/resource/test_utility.go +++ b/internal/streamingnode/server/resource/test_utility.go @@ -7,7 +7,6 @@ import ( "testing" "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" - sinspector "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/inspector" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" tinspector "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/inspector" ) @@ -27,6 +26,5 @@ func InitForTest(t *testing.T, opts ...optResourceInit) { r.idAllocator = idalloc.NewIDAllocator(r.rootCoordClient) } r.segmentAssignStatsManager = stats.NewStatsManager() - r.segmentSealedInspector = sinspector.NewSealedInspector(r.segmentAssignStatsManager.SealNotifier()) r.timeTickInspector = tinspector.NewTimeTickSyncInspector() } diff --git a/internal/streamingnode/server/server.go b/internal/streamingnode/server/server.go index cfae4ddbc8..01977259bf 100644 --- a/internal/streamingnode/server/server.go +++ b/internal/streamingnode/server/server.go @@ -7,6 +7,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/service" "github.com/milvus-io/milvus/internal/streamingnode/server/walmanager" "github.com/milvus-io/milvus/internal/util/componentutil" @@ -48,7 +49,8 @@ func (s *Server) Init(ctx context.Context) (err error) { // Start starts the streamingnode server. func (s *Server) Start() { - // Just do nothing now. + resource.Resource().Flusher().Start() + log.Info("flusher started") } // Stop stops the streamingnode server. @@ -58,6 +60,9 @@ func (s *Server) Stop() { log.Info("close wal manager...") s.walManager.Close() log.Info("streamingnode server stopped") + log.Info("stopping flusher...") + resource.Resource().Flusher().Stop() + log.Info("flusher stopped") } // Health returns the health status of the streamingnode server. diff --git a/internal/streamingnode/server/wal/adaptor/wal_adaptor.go b/internal/streamingnode/server/wal/adaptor/wal_adaptor.go index 992bcdd595..49291cfc29 100644 --- a/internal/streamingnode/server/wal/adaptor/wal_adaptor.go +++ b/internal/streamingnode/server/wal/adaptor/wal_adaptor.go @@ -167,24 +167,37 @@ func (w *walAdaptorImpl) Available() <-chan struct{} { // Close overrides Scanner Close function. func (w *walAdaptorImpl) Close() { + logger := log.With(zap.Any("channel", w.Channel()), zap.String("processing", "WALClose")) + logger.Info("wal begin to close, start graceful close...") // graceful close the interceptors before wal closing. w.interceptorBuildResult.GracefulCloseFunc() + logger.Info("wal graceful close done, wait for operation to be finished...") + // begin to close the wal. w.lifetime.SetState(lifetime.Stopped) w.lifetime.Wait() w.lifetime.Close() + logger.Info("wal begin to close scanners...") + // close all wal instances. 
w.scanners.Range(func(id int64, s wal.Scanner) bool { s.Close() - log.Info("close scanner by wal extend", zap.Int64("id", id), zap.Any("channel", w.Channel())) + log.Info("close scanner by wal adaptor", zap.Int64("id", id), zap.Any("channel", w.Channel())) return true }) + + logger.Info("scanner close done, close inner wal...") w.inner.Close() + + logger.Info("scanner close done, close interceptors...") w.interceptorBuildResult.Close() w.appendExecutionPool.Free() + + logger.Info("call wal cleanup function...") w.cleanup() + logger.Info("wal closed") } type interceptorBuildResult struct { diff --git a/internal/streamingnode/server/wal/interceptors/segment/inspector/impls.go b/internal/streamingnode/server/wal/interceptors/segment/inspector/impls.go index 91619334e1..32ee6b8299 100644 --- a/internal/streamingnode/server/wal/interceptors/segment/inspector/impls.go +++ b/internal/streamingnode/server/wal/interceptors/segment/inspector/impls.go @@ -4,13 +4,15 @@ import ( "context" "time" + "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" "github.com/milvus-io/milvus/pkg/util/syncutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) const ( - defaultSealAllInterval = 10 * time.Second + defaultSealAllInterval = 10 * time.Second + defaultMustSealInterval = 200 * time.Millisecond ) // NewSealedInspector creates a new seal inspector. @@ -82,6 +84,9 @@ func (s *sealOperationInspectorImpl) background() { sealAllTicker := time.NewTicker(defaultSealAllInterval) defer sealAllTicker.Stop() + mustSealTicker := time.NewTicker(defaultMustSealInterval) + defer mustSealTicker.Stop() + var backoffCh <-chan time.Time for { if s.shouldEnableBackoff() { @@ -112,6 +117,11 @@ func (s *sealOperationInspectorImpl) background() { pm.TryToSealSegments(s.taskNotifier.Context()) return true }) + case <-mustSealTicker.C: + segmentBelongs := resource.Resource().SegmentAssignStatsManager().SealByTotalGrowingSegmentsSize() + if pm, ok := s.managers.Get(segmentBelongs.PChannel); ok { + pm.MustSealSegments(s.taskNotifier.Context(), segmentBelongs) + } } } } diff --git a/internal/streamingnode/server/wal/interceptors/segment/inspector/inspector.go b/internal/streamingnode/server/wal/interceptors/segment/inspector/inspector.go index 3fef273441..caa1e4155f 100644 --- a/internal/streamingnode/server/wal/interceptors/segment/inspector/inspector.go +++ b/internal/streamingnode/server/wal/interceptors/segment/inspector/inspector.go @@ -2,11 +2,25 @@ package inspector import ( "context" + "sync" + "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" "github.com/milvus-io/milvus/pkg/streaming/util/types" ) +var ( + segmentSealedInspector SealOperationInspector + initOnce sync.Once +) + +func GetSegmentSealedInspector() SealOperationInspector { + initOnce.Do(func() { + segmentSealedInspector = NewSealedInspector(resource.Resource().SegmentAssignStatsManager().SealNotifier()) + }) + return segmentSealedInspector +} + // SealOperationInspector is the inspector to check if a segment should be sealed or not. type SealOperationInspector interface { // TriggerSealWaited triggers the seal waited segment. @@ -36,6 +50,9 @@ type SealOperator interface { // Return false if there's some segment wait for seal but not sealed. 
TryToSealWaitedSegment(ctx context.Context) + // MustSealSegments seals the given segments and waiting seal segments. + MustSealSegments(ctx context.Context, infos ...stats.SegmentBelongs) + // IsNoWaitSeal returns whether there's no segment wait for seal. IsNoWaitSeal() bool } diff --git a/internal/streamingnode/server/wal/interceptors/segment/inspector/inspector_test.go b/internal/streamingnode/server/wal/interceptors/segment/inspector/inspector_test.go index 5e2894a216..d795cc19cb 100644 --- a/internal/streamingnode/server/wal/interceptors/segment/inspector/inspector_test.go +++ b/internal/streamingnode/server/wal/interceptors/segment/inspector/inspector_test.go @@ -10,11 +10,16 @@ import ( "go.uber.org/atomic" "github.com/milvus-io/milvus/internal/mocks/streamingnode/server/wal/interceptors/segment/mock_inspector" + "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) func TestSealedInspector(t *testing.T) { + paramtable.Init() + resource.InitForTest(t) + notifier := stats.NewSealSignalNotifier() inspector := NewSealedInspector(notifier) @@ -52,6 +57,7 @@ func TestSealedInspector(t *testing.T) { VChannel: "vv1", CollectionID: 12, PartitionID: 1, + SegmentID: 2, }) time.Sleep(5 * time.Millisecond) } diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/partition_manager.go b/internal/streamingnode/server/wal/interceptors/segment/manager/partition_manager.go index e102d114d3..99879b169d 100644 --- a/internal/streamingnode/server/wal/interceptors/segment/manager/partition_manager.go +++ b/internal/streamingnode/server/wal/interceptors/segment/manager/partition_manager.go @@ -5,11 +5,13 @@ import ( "sync" "github.com/cockroachdb/errors" + "github.com/samber/lo" "go.uber.org/zap" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/policy" + "github.com/milvus-io/milvus/internal/util/streamingutil/status" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" "github.com/milvus-io/milvus/pkg/streaming/util/types" @@ -97,6 +99,22 @@ func (m *partitionSegmentManager) CollectShouldBeSealed() []*segmentAllocManager return m.collectShouldBeSealedWithPolicy(m.hitSealPolicy) } +// CollectionMustSealed seals the specified segment. +func (m *partitionSegmentManager) CollectionMustSealed(segmentID int64) *segmentAllocManager { + m.mu.Lock() + defer m.mu.Unlock() + + var target *segmentAllocManager + m.segments = lo.Filter(m.segments, func(segment *segmentAllocManager, _ int) bool { + if segment.inner.GetSegmentId() == segmentID { + target = segment + return false + } + return true + }) + return target +} + // collectShouldBeSealedWithPolicy collects all segments that should be sealed by policy. 
func (m *partitionSegmentManager) collectShouldBeSealedWithPolicy(predicates func(segmentMeta *segmentAllocManager) bool) []*segmentAllocManager { shouldBeSealedSegments := make([]*segmentAllocManager, 0, len(m.segments)) @@ -267,5 +285,5 @@ func (m *partitionSegmentManager) assignSegment(ctx context.Context, req *Assign if inserted, ack := newGrowingSegment.AllocRows(ctx, req); inserted { return &AssignSegmentResult{SegmentID: newGrowingSegment.GetSegmentID(), Acknowledge: ack}, nil } - return nil, errors.Errorf("too large insert message, cannot hold in empty growing segment, stats: %+v", req.InsertMetrics) + return nil, status.NewUnrecoverableError("too large insert message, cannot hold in empty growing segment, stats: %+v", req.InsertMetrics) } diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/partition_managers.go b/internal/streamingnode/server/wal/interceptors/segment/manager/partition_managers.go index 2cec243849..f66e6bab37 100644 --- a/internal/streamingnode/server/wal/interceptors/segment/manager/partition_managers.go +++ b/internal/streamingnode/server/wal/interceptors/segment/manager/partition_managers.go @@ -7,6 +7,7 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/util/streamingutil/status" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" "github.com/milvus-io/milvus/pkg/streaming/util/types" @@ -151,7 +152,7 @@ func (m *partitionSegmentManagers) NewPartition(collectionID int64, partitionID func (m *partitionSegmentManagers) Get(collectionID int64, partitionID int64) (*partitionSegmentManager, error) { pm, ok := m.managers.Get(partitionID) if !ok { - return nil, errors.Errorf("partition %d in collection %d not found in segment assignment service", partitionID, collectionID) + return nil, status.NewUnrecoverableError("partition %d in collection %d not found in segment assignment service", partitionID, collectionID) } return pm, nil } diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager.go b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager.go index adcf87370a..1b8e1bce87 100644 --- a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager.go +++ b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager.go @@ -9,6 +9,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/inspector" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" @@ -113,7 +114,7 @@ func (m *PChannelSegmentAllocManager) RemoveCollection(ctx context.Context, coll m.helper.AsyncSeal(waitForSealed...) // trigger a seal operation in background rightnow. - resource.Resource().SegmentSealedInspector().TriggerSealWaited(ctx, m.pchannel.Name) + inspector.GetSegmentSealedInspector().TriggerSealWaited(ctx, m.pchannel.Name) // wait for all segment has been flushed. return m.helper.WaitUntilNoWaitSeal(ctx) @@ -132,7 +133,7 @@ func (m *PChannelSegmentAllocManager) RemovePartition(ctx context.Context, colle m.helper.AsyncSeal(waitForSealed...) 
// trigger a seal operation in background rightnow. - resource.Resource().SegmentSealedInspector().TriggerSealWaited(ctx, m.pchannel.Name) + inspector.GetSegmentSealedInspector().TriggerSealWaited(ctx, m.pchannel.Name) // wait for all segment has been flushed. return m.helper.WaitUntilNoWaitSeal(ctx) @@ -191,6 +192,20 @@ func (m *PChannelSegmentAllocManager) TryToSealSegments(ctx context.Context, inf m.helper.SealAllWait(ctx) } +func (m *PChannelSegmentAllocManager) MustSealSegments(ctx context.Context, infos ...stats.SegmentBelongs) { + if err := m.lifetime.Add(lifetime.IsWorking); err != nil { + return + } + defer m.lifetime.Done() + + for _, info := range infos { + if pm, err := m.managers.Get(info.CollectionID, info.PartitionID); err == nil { + m.helper.AsyncSeal(pm.CollectionMustSealed(info.SegmentID)) + } + } + m.helper.SealAllWait(ctx) +} + // TryToSealWaitedSegment tries to seal the wait for sealing segment. func (m *PChannelSegmentAllocManager) TryToSealWaitedSegment(ctx context.Context) { if err := m.lifetime.Add(lifetime.IsWorking); err != nil { diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go index ae7dbe36f6..ab4b324447 100644 --- a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go +++ b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go @@ -17,6 +17,7 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/inspector" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/txn" "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" @@ -129,6 +130,7 @@ func TestSegmentAllocManager(t *testing.T) { VChannel: "v1", PartitionID: 2, PChannel: "v1", + SegmentID: 3, }) assert.True(t, m.IsNoWaitSeal()) @@ -195,7 +197,7 @@ func TestCreateAndDropCollection(t *testing.T) { m, err := RecoverPChannelSegmentAllocManager(context.Background(), types.PChannelInfo{Name: "v1"}, f) assert.NoError(t, err) assert.NotNil(t, m) - resource.Resource().SegmentSealedInspector().RegsiterPChannelManager(m) + inspector.GetSegmentSealedInspector().RegsiterPChannelManager(m) ctx := context.Background() diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/segment_manager.go b/internal/streamingnode/server/wal/interceptors/segment/manager/segment_manager.go index 594648a661..afa81221ae 100644 --- a/internal/streamingnode/server/wal/interceptors/segment/manager/segment_manager.go +++ b/internal/streamingnode/server/wal/interceptors/segment/manager/segment_manager.go @@ -30,6 +30,7 @@ func newSegmentAllocManagerFromProto( resource.Resource().SegmentAssignStatsManager().RegisterNewGrowingSegment(stats.SegmentBelongs{ CollectionID: inner.GetCollectionId(), PartitionID: inner.GetPartitionId(), + SegmentID: inner.GetSegmentId(), PChannel: pchannel.Name, VChannel: inner.GetVchannel(), }, inner.GetSegmentId(), stat) @@ -253,6 +254,7 @@ func (m *mutableSegmentAssignmentMeta) Commit(ctx context.Context) error { resource.Resource().SegmentAssignStatsManager().RegisterNewGrowingSegment(stats.SegmentBelongs{ CollectionID: 
m.original.GetCollectionID(), PartitionID: m.original.GetPartitionID(), + SegmentID: m.original.GetSegmentID(), PChannel: m.original.pchannel.Name, VChannel: m.original.GetVChannel(), }, m.original.GetSegmentID(), stats.NewSegmentStatFromProto(m.modifiedCopy.Stat)) diff --git a/internal/streamingnode/server/wal/interceptors/segment/segment_assign_interceptor.go b/internal/streamingnode/server/wal/interceptors/segment/segment_assign_interceptor.go index 1d85a487bd..bb2c815208 100644 --- a/internal/streamingnode/server/wal/interceptors/segment/segment_assign_interceptor.go +++ b/internal/streamingnode/server/wal/interceptors/segment/segment_assign_interceptor.go @@ -5,12 +5,14 @@ import ( "time" "go.uber.org/zap" + "google.golang.org/protobuf/types/known/anypb" - "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/inspector" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/manager" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/txn" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility" "github.com/milvus-io/milvus/internal/util/streamingutil/status" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/streaming/util/message" @@ -48,6 +50,8 @@ func (impl *segmentInterceptor) DoAppend(ctx context.Context, msg message.Mutabl return impl.handleDropPartition(ctx, msg, appendOp) case message.MessageTypeInsert: return impl.handleInsertMessage(ctx, msg, appendOp) + case message.MessageTypeManualFlush: + return impl.handleManualFlushMessage(ctx, msg, appendOp) default: return appendOp(ctx, msg) } @@ -144,7 +148,7 @@ func (impl *segmentInterceptor) handleInsertMessage(ctx context.Context, msg mes TxnSession: txn.GetTxnSessionFromContext(ctx), }) if err != nil { - return nil, status.NewInner("segment assignment failure with error: %s", err.Error()) + return nil, err } // once the segment assignment is done, we need to ack the result, // if other partitions failed to assign segment or wal write failure, @@ -162,12 +166,43 @@ func (impl *segmentInterceptor) handleInsertMessage(ctx context.Context, msg mes return appendOp(ctx, msg) } +// handleManualFlushMessage handles the manual flush message. +func (impl *segmentInterceptor) handleManualFlushMessage(ctx context.Context, msg message.MutableMessage, appendOp interceptors.Append) (message.MessageID, error) { + maunalFlushMsg, err := message.AsMutableManualFlushMessageV2(msg) + if err != nil { + return nil, err + } + header := maunalFlushMsg.Header() + segmentIDs, err := impl.assignManager.Get().SealAllSegmentsAndFenceUntil(ctx, header.GetCollectionId(), header.GetFlushTs()) + if err != nil { + return nil, status.NewInner("segment seal failure with error: %s", err.Error()) + } + + // create extra response for manual flush message. + extraResponse, err := anypb.New(&message.ManualFlushExtraResponse{ + SegmentIds: segmentIDs, + }) + if err != nil { + return nil, status.NewInner("create extra response failed with error: %s", err.Error()) + } + + // send the manual flush message. + msgID, err := appendOp(ctx, msg) + if err != nil { + return nil, err + } + + utility.AttachAppendResultExtra(ctx, extraResponse) + return msgID, nil +} + // Close closes the segment interceptor. 
func (impl *segmentInterceptor) Close() { + impl.cancel() assignManager := impl.assignManager.Get() if assignManager != nil { // unregister the pchannels - resource.Resource().SegmentSealedInspector().UnregisterPChannelManager(assignManager) + inspector.GetSegmentSealedInspector().UnregisterPChannelManager(assignManager) assignManager.Close(context.Background()) } } @@ -199,7 +234,7 @@ func (impl *segmentInterceptor) recoverPChannelManager(param interceptors.Interc } // register the manager into inspector, to do the seal asynchronously - resource.Resource().SegmentSealedInspector().RegsiterPChannelManager(pm) + inspector.GetSegmentSealedInspector().RegsiterPChannelManager(pm) impl.assignManager.Set(pm) impl.logger.Info("recover PChannel Assignment Manager success") return diff --git a/internal/streamingnode/server/wal/interceptors/segment/stats/stats_manager.go b/internal/streamingnode/server/wal/interceptors/segment/stats/stats_manager.go index 3fdd80bc9e..864d1221a8 100644 --- a/internal/streamingnode/server/wal/interceptors/segment/stats/stats_manager.go +++ b/internal/streamingnode/server/wal/interceptors/segment/stats/stats_manager.go @@ -3,6 +3,11 @@ package stats import ( "fmt" "sync" + + "github.com/pingcap/log" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/pkg/util/paramtable" ) // StatsManager is the manager of stats. @@ -23,6 +28,7 @@ type SegmentBelongs struct { VChannel string CollectionID int64 PartitionID int64 + SegmentID int64 } // NewStatsManager creates a new stats manager. @@ -153,6 +159,34 @@ func (m *StatsManager) UnregisterSealedSegment(segmentID int64) *SegmentStats { return stats } +// SealByTotalGrowingSegmentsSize seals the largest growing segment +// if the total size of growing segments in ANY vchannel exceeds the threshold. +func (m *StatsManager) SealByTotalGrowingSegmentsSize() SegmentBelongs { + m.mu.Lock() + defer m.mu.Unlock() + + for vchannel, metrics := range m.vchannelStats { + threshold := paramtable.Get().DataCoordCfg.GrowingSegmentsMemSizeInMB.GetAsUint64() * 1024 * 1024 + if metrics.BinarySize >= threshold { + var ( + largestSegment int64 = 0 + largestSegmentSize uint64 = 0 + ) + for segmentID, stats := range m.segmentStats { + if stats.Insert.BinarySize > largestSegmentSize { + largestSegmentSize = stats.Insert.BinarySize + largestSegment = segmentID + } + } + log.Info("seal by total growing segments size", zap.String("vchannel", vchannel), + zap.Uint64("vchannelGrowingSize", metrics.BinarySize), zap.Uint64("sealThreshold", threshold), + zap.Int64("sealSegment", largestSegment), zap.Uint64("sealSegmentSize", largestSegmentSize)) + return m.segmentIndex[largestSegment] + } + } + return SegmentBelongs{} +} + // InsertOpeatationMetrics is the metrics of insert operation. 
type InsertMetrics struct { Rows uint64 diff --git a/internal/streamingnode/server/wal/interceptors/segment/stats/stats_manager_test.go b/internal/streamingnode/server/wal/interceptors/segment/stats/stats_manager_test.go index 0a01abbb7c..47d53cc6e5 100644 --- a/internal/streamingnode/server/wal/interceptors/segment/stats/stats_manager_test.go +++ b/internal/streamingnode/server/wal/interceptors/segment/stats/stats_manager_test.go @@ -10,32 +10,32 @@ import ( func TestStatsManager(t *testing.T) { m := NewStatsManager() - m.RegisterNewGrowingSegment(SegmentBelongs{PChannel: "pchannel", VChannel: "vchannel", CollectionID: 1, PartitionID: 2}, 3, createSegmentStats(100, 100, 300)) + m.RegisterNewGrowingSegment(SegmentBelongs{PChannel: "pchannel", VChannel: "vchannel", CollectionID: 1, PartitionID: 2, SegmentID: 3}, 3, createSegmentStats(100, 100, 300)) assert.Len(t, m.segmentStats, 1) assert.Len(t, m.vchannelStats, 1) assert.Len(t, m.pchannelStats, 1) assert.Len(t, m.segmentIndex, 1) - m.RegisterNewGrowingSegment(SegmentBelongs{PChannel: "pchannel", VChannel: "vchannel", CollectionID: 1, PartitionID: 3}, 4, createSegmentStats(100, 100, 300)) + m.RegisterNewGrowingSegment(SegmentBelongs{PChannel: "pchannel", VChannel: "vchannel", CollectionID: 1, PartitionID: 3, SegmentID: 4}, 4, createSegmentStats(100, 100, 300)) assert.Len(t, m.segmentStats, 2) assert.Len(t, m.segmentIndex, 2) assert.Len(t, m.vchannelStats, 1) assert.Len(t, m.pchannelStats, 1) - m.RegisterNewGrowingSegment(SegmentBelongs{PChannel: "pchannel", VChannel: "vchannel2", CollectionID: 2, PartitionID: 4}, 5, createSegmentStats(100, 100, 300)) + m.RegisterNewGrowingSegment(SegmentBelongs{PChannel: "pchannel", VChannel: "vchannel2", CollectionID: 2, PartitionID: 4, SegmentID: 5}, 5, createSegmentStats(100, 100, 300)) assert.Len(t, m.segmentStats, 3) assert.Len(t, m.segmentIndex, 3) assert.Len(t, m.vchannelStats, 2) assert.Len(t, m.pchannelStats, 1) - m.RegisterNewGrowingSegment(SegmentBelongs{PChannel: "pchannel2", VChannel: "vchannel3", CollectionID: 2, PartitionID: 5}, 6, createSegmentStats(100, 100, 300)) + m.RegisterNewGrowingSegment(SegmentBelongs{PChannel: "pchannel2", VChannel: "vchannel3", CollectionID: 2, PartitionID: 5, SegmentID: 6}, 6, createSegmentStats(100, 100, 300)) assert.Len(t, m.segmentStats, 4) assert.Len(t, m.segmentIndex, 4) assert.Len(t, m.vchannelStats, 3) assert.Len(t, m.pchannelStats, 2) assert.Panics(t, func() { - m.RegisterNewGrowingSegment(SegmentBelongs{PChannel: "pchannel", VChannel: "vchannel", CollectionID: 1, PartitionID: 2}, 3, createSegmentStats(100, 100, 300)) + m.RegisterNewGrowingSegment(SegmentBelongs{PChannel: "pchannel", VChannel: "vchannel", CollectionID: 1, PartitionID: 2, SegmentID: 3}, 3, createSegmentStats(100, 100, 300)) }) shouldBlock(t, m.SealNotifier().WaitChan()) diff --git a/internal/streamingnode/server/wal/interceptors/timetick/ack/detail_test.go b/internal/streamingnode/server/wal/interceptors/timetick/ack/detail_test.go index 1a9dc27cfe..f1062ecc0b 100644 --- a/internal/streamingnode/server/wal/interceptors/timetick/ack/detail_test.go +++ b/internal/streamingnode/server/wal/interceptors/timetick/ack/detail_test.go @@ -15,8 +15,7 @@ func TestDetail(t *testing.T) { assert.Panics(t, func() { newAckDetail(0, mock_message.NewMockMessageID(t)) }) - msgID := mock_message.NewMockMessageID(t) - msgID.EXPECT().EQ(msgID).Return(true) + msgID := walimplstest.NewTestMessageID(1) ackDetail := newAckDetail(1, msgID) assert.Equal(t, uint64(1), ackDetail.BeginTimestamp) diff --git 
a/internal/streamingnode/server/wal/interceptors/timetick/inspector/notifier.go b/internal/streamingnode/server/wal/interceptors/timetick/inspector/notifier.go index 66c1b66ff5..1ef94e3dce 100644 --- a/internal/streamingnode/server/wal/interceptors/timetick/inspector/notifier.go +++ b/internal/streamingnode/server/wal/interceptors/timetick/inspector/notifier.go @@ -101,10 +101,17 @@ func (l *TimeTickNotifier) OnlyUpdateTs(timetick uint64) { // Or if the time tick is less than the last time tick, return channel. func (l *TimeTickNotifier) WatchAtMessageID(messageID message.MessageID, ts uint64) <-chan struct{} { l.cond.L.Lock() - if l.info.IsZero() || !l.info.MessageID.EQ(messageID) { + // If incoming messageID is less than the producer messageID, + // the consumer can read the new greater messageID from wal, + // so the watch operation is not necessary. + if l.info.IsZero() || messageID.LT(l.info.MessageID) { l.cond.L.Unlock() return nil } + + // messageID may be greater than MessageID in notifier. + // because consuming operation is fast than produce operation. + // so doing a listening here. if ts < l.info.TimeTick { ch := make(chan struct{}) close(ch) diff --git a/internal/streamingnode/server/wal/interceptors/timetick/timetick_interceptor.go b/internal/streamingnode/server/wal/interceptors/timetick/timetick_interceptor.go index 4786b745f0..1acb87c22a 100644 --- a/internal/streamingnode/server/wal/interceptors/timetick/timetick_interceptor.go +++ b/internal/streamingnode/server/wal/interceptors/timetick/timetick_interceptor.go @@ -5,6 +5,7 @@ import ( "time" "github.com/cockroachdb/errors" + "go.uber.org/zap" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors" @@ -90,7 +91,7 @@ func (impl *timeTickAppendInterceptor) DoAppend(ctx context.Context, msg message txnSession.AddNewMessageFail() } // perform keepalive for the transaction session if append success. - txnSession.AddNewMessageAndKeepalive(msg.TimeTick()) + txnSession.AddNewMessageDoneAndKeepalive(msg.TimeTick()) }() } } @@ -106,8 +107,15 @@ func (impl *timeTickAppendInterceptor) DoAppend(ctx context.Context, msg message // GracefulClose implements InterceptorWithGracefulClose. func (impl *timeTickAppendInterceptor) GracefulClose() { - log.Warn("timeTickAppendInterceptor is closing") - impl.txnManager.GracefulClose() + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + logger := log.With(zap.Any("pchannel", impl.operator.pchannel)) + logger.Info("timeTickAppendInterceptor is closing, try to perform a txn manager graceful shutdown") + if err := impl.txnManager.GracefulClose(ctx); err != nil { + logger.Warn("timeTickAppendInterceptor is closed", zap.Error(err)) + return + } + logger.Info("txnManager of timeTickAppendInterceptor is graceful closed") } // Close implements AppendInterceptor. diff --git a/internal/streamingnode/server/wal/interceptors/txn/session.go b/internal/streamingnode/server/wal/interceptors/txn/session.go index 3f5d385742..41b0e39a6b 100644 --- a/internal/streamingnode/server/wal/interceptors/txn/session.go +++ b/internal/streamingnode/server/wal/interceptors/txn/session.go @@ -58,12 +58,13 @@ func (s *TxnSession) BeginRollback() { // AddNewMessage adds a new message to the session. func (s *TxnSession) AddNewMessage(ctx context.Context, timetick uint64) error { + s.mu.Lock() + defer s.mu.Unlock() + // if the txn is expired, return error. 
if err := s.checkIfExpired(timetick); err != nil { return err } - s.mu.Lock() - defer s.mu.Unlock() if s.state != message.TxnStateInFlight { return status.NewInvalidTransactionState("AddNewMessage", message.TxnStateInFlight, s.state) @@ -72,9 +73,9 @@ func (s *TxnSession) AddNewMessage(ctx context.Context, timetick uint64) error { return nil } -// AddNewMessageAndKeepalive decreases the in flight count of the session and keepalive the session. +// AddNewMessageDoneAndKeepalive decreases the in flight count of the session and keepalive the session. // notify the committedWait channel if the in flight count is 0 and committed waited. -func (s *TxnSession) AddNewMessageAndKeepalive(timetick uint64) { +func (s *TxnSession) AddNewMessageDoneAndKeepalive(timetick uint64) { s.mu.Lock() defer s.mu.Unlock() diff --git a/internal/streamingnode/server/wal/interceptors/txn/session_test.go b/internal/streamingnode/server/wal/interceptors/txn/session_test.go index 30e067c661..68a2bd84ab 100644 --- a/internal/streamingnode/server/wal/interceptors/txn/session_test.go +++ b/internal/streamingnode/server/wal/interceptors/txn/session_test.go @@ -64,7 +64,7 @@ func TestSession(t *testing.T) { assert.NoError(t, err) err = session.AddNewMessage(ctx, 0) assert.NoError(t, err) - session.AddNewMessageAndKeepalive(0) + session.AddNewMessageDoneAndKeepalive(0) // Test Commit. err = session.RequestCommitAndWait(ctx, 0) @@ -147,7 +147,7 @@ func TestManager(t *testing.T) { closed := make(chan struct{}) go func() { - m.GracefulClose() + m.GracefulClose(context.Background()) close(closed) }() diff --git a/internal/streamingnode/server/wal/interceptors/txn/txn_manager.go b/internal/streamingnode/server/wal/interceptors/txn/txn_manager.go index 7ada33350b..6bdb427b2b 100644 --- a/internal/streamingnode/server/wal/interceptors/txn/txn_manager.go +++ b/internal/streamingnode/server/wal/interceptors/txn/txn_manager.go @@ -5,8 +5,11 @@ import ( "sync" "time" + "go.uber.org/zap" + "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/util/streamingutil/status" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/streaming/util/message" "github.com/milvus-io/milvus/pkg/util/lifetime" ) @@ -93,7 +96,7 @@ func (m *TxnManager) GetSessionOfTxn(id message.TxnID) (*TxnSession, error) { } // GracefulClose waits for all transactions to be cleaned up. 
-func (m *TxnManager) GracefulClose() { +func (m *TxnManager) GracefulClose(ctx context.Context) error { m.mu.Lock() if m.closed == nil { m.closed = lifetime.NewSafeChan() @@ -101,7 +104,13 @@ func (m *TxnManager) GracefulClose() { m.closed.Close() } } + log.Info("there's still txn session in txn manager, waiting for them to be consumed", zap.Int("session count", len(m.sessions))) m.mu.Unlock() - <-m.closed.CloseCh() + select { + case <-ctx.Done(): + return ctx.Err() + case <-m.closed.CloseCh(): + return nil + } } diff --git a/internal/util/pipeline/stream_pipeline.go b/internal/util/pipeline/stream_pipeline.go index 9485129f89..2765e11492 100644 --- a/internal/util/pipeline/stream_pipeline.go +++ b/internal/util/pipeline/stream_pipeline.go @@ -24,10 +24,15 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/internal/distributed/streaming" + "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/common" "github.com/milvus-io/milvus/pkg/mq/msgdispatcher" "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/message/adaptor" + "github.com/milvus-io/milvus/pkg/streaming/util/options" "github.com/milvus-io/milvus/pkg/util/tsoutil" ) @@ -39,6 +44,7 @@ type StreamPipeline interface { type streamPipeline struct { pipeline *pipeline input <-chan *msgstream.MsgPack + scanner streaming.Scanner dispatcher msgdispatcher.Client startOnce sync.Once vChannel string @@ -70,6 +76,30 @@ func (p *streamPipeline) ConsumeMsgStream(position *msgpb.MsgPosition) error { return ErrNilPosition } + if streamingutil.IsStreamingServiceEnabled() { + startFrom := adaptor.MustGetMessageIDFromMQWrapperIDBytes("pulsar", position.GetMsgID()) + log.Info( + "stream pipeline seeks from position with scanner", + zap.String("channel", position.GetChannelName()), + zap.Any("startFromMessageID", startFrom), + zap.Uint64("timestamp", position.GetTimestamp()), + ) + handler := adaptor.NewMsgPackAdaptorHandler() + p.scanner = streaming.WAL().Read(context.Background(), streaming.ReadOption{ + VChannel: position.GetChannelName(), + DeliverPolicy: options.DeliverPolicyStartFrom(startFrom), + DeliverFilters: []options.DeliverFilter{ + // only consume messages with timestamp >= position timestamp + options.DeliverFilterTimeTickGTE(position.GetTimestamp()), + // only consume insert and delete messages + options.DeliverFilterMessageType(message.MessageTypeInsert, message.MessageTypeDelete), + }, + MessageHandler: handler, + }) + p.input = handler.Chan() + return nil + } + start := time.Now() p.input, err = p.dispatcher.Register(context.TODO(), p.vChannel, position, common.SubscriptionPositionUnknown) if err != nil { @@ -105,6 +135,9 @@ func (p *streamPipeline) Close() { p.closeOnce.Do(func() { close(p.closeCh) p.closeWg.Wait() + if p.scanner != nil { + p.scanner.Close() + } p.dispatcher.Deregister(p.vChannel) p.pipeline.Close() }) diff --git a/internal/util/streamingutil/checker.go b/internal/util/streamingutil/env.go similarity index 59% rename from internal/util/streamingutil/checker.go rename to internal/util/streamingutil/env.go index 6572797a0c..8c81c685fc 100644 --- a/internal/util/streamingutil/checker.go +++ b/internal/util/streamingutil/env.go @@ -10,22 +10,6 @@ func IsStreamingServiceEnabled() bool { return os.Getenv(MilvusStreamingServiceEnabled) == "1" } -// SetStreamingServiceEnabled sets the 
env that indicates whether the streaming service is enabled. -func SetStreamingServiceEnabled() { - err := os.Setenv(MilvusStreamingServiceEnabled, "1") - if err != nil { - panic(err) - } -} - -// UnsetStreamingServiceEnabled unsets the env that indicates whether the streaming service is enabled. -func UnsetStreamingServiceEnabled() { - err := os.Setenv(MilvusStreamingServiceEnabled, "0") - if err != nil { - panic(err) - } -} - // MustEnableStreamingService panics if the streaming service is not enabled. func MustEnableStreamingService() { if !IsStreamingServiceEnabled() { diff --git a/internal/util/streamingutil/status/streaming_error.go b/internal/util/streamingutil/status/streaming_error.go index 1fa176fb49..8e1fc9d155 100644 --- a/internal/util/streamingutil/status/streaming_error.go +++ b/internal/util/streamingutil/status/streaming_error.go @@ -43,6 +43,12 @@ func (e *StreamingError) IsSkippedOperation() bool { e.Code == streamingpb.StreamingCode_STREAMING_CODE_UNMATCHED_CHANNEL_TERM } +// IsUnrecoverable returns true if the error is unrecoverable. +// Stop resuming retry and report to user. +func (e *StreamingError) IsUnrecoverable() bool { + return e.Code == streamingpb.StreamingCode_STREAMING_CODE_UNRECOVERABLE || e.IsTxnUnavilable() +} + // IsTxnUnavilable returns true if the transaction is unavailable. func (e *StreamingError) IsTxnUnavilable() bool { return e.Code == streamingpb.StreamingCode_STREAMING_CODE_TRANSACTION_EXPIRED || @@ -105,6 +111,11 @@ func NewInvalidTransactionState(operation string, expectState message.TxnState, return New(streamingpb.StreamingCode_STREAMING_CODE_INVALID_TRANSACTION_STATE, "invalid transaction state for operation %s, expect %s, current %s", operation, expectState, currentState) } +// NewUnrecoverableError creates a new StreamingError with code STREAMING_CODE_UNRECOVERABLE. +func NewUnrecoverableError(format string, args ...interface{}) *StreamingError { + return New(streamingpb.StreamingCode_STREAMING_CODE_UNRECOVERABLE, format, args...) +} + // New creates a new StreamingError with the given code and cause. func New(code streamingpb.StreamingCode, format string, args ...interface{}) *StreamingError { if len(args) == 0 { diff --git a/internal/util/streamingutil/test_env.go b/internal/util/streamingutil/test_env.go new file mode 100644 index 0000000000..d0c5a5237e --- /dev/null +++ b/internal/util/streamingutil/test_env.go @@ -0,0 +1,22 @@ +//go:build test +// +build test + +package streamingutil + +import "os" + +// SetStreamingServiceEnabled set the env that indicates whether the streaming service is enabled. +func SetStreamingServiceEnabled() { + err := os.Setenv(MilvusStreamingServiceEnabled, "1") + if err != nil { + panic(err) + } +} + +// UnsetStreamingServiceEnabled unsets the env that indicates whether the streaming service is enabled. 
+func UnsetStreamingServiceEnabled() { + err := os.Setenv(MilvusStreamingServiceEnabled, "0") + if err != nil { + panic(err) + } +} diff --git a/pkg/go.mod b/pkg/go.mod index 65bd79ee10..1171b67820 100644 --- a/pkg/go.mod +++ b/pkg/go.mod @@ -13,7 +13,7 @@ require ( github.com/expr-lang/expr v1.15.7 github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 github.com/klauspost/compress v1.17.7 - github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240815123953-6dab6fcd6454 + github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240820032106-b34be93a2271 github.com/nats-io/nats-server/v2 v2.10.12 github.com/nats-io/nats.go v1.34.1 github.com/panjf2000/ants/v2 v2.7.2 diff --git a/pkg/go.sum b/pkg/go.sum index 86c4eb3dad..45c207a36e 100644 --- a/pkg/go.sum +++ b/pkg/go.sum @@ -494,8 +494,8 @@ github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119 h1:9VXijWu github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119/go.mod h1:DvXTE/K/RtHehxU8/GtDs4vFtfw64jJ3PaCnFri8CRg= github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b h1:TfeY0NxYxZzUfIfYe5qYDBzt4ZYRqzUjTR6CvUzjat8= github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b/go.mod h1:iwW+9cWfIzzDseEBCCeDSN5SD16Tidvy8cwQ7ZY8Qj4= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240815123953-6dab6fcd6454 h1:JmZCYjMPpiE4ksZw0AUxXWkDY7wwA4fhS+SO1N211Vw= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240815123953-6dab6fcd6454/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs= +github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240820032106-b34be93a2271 h1:YUWBgtRHmvkxMPTfOrY3FIq0K5XHw02Z18z7cyaMH04= +github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240820032106-b34be93a2271/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs= github.com/milvus-io/pulsar-client-go v0.6.10 h1:eqpJjU+/QX0iIhEo3nhOqMNXL+TyInAs1IAHZCrCM/A= github.com/milvus-io/pulsar-client-go v0.6.10/go.mod h1:lQqCkgwDF8YFYjKA+zOheTk1tev2B+bKj5j7+nm8M1w= github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g= diff --git a/pkg/mocks/streaming/util/mock_message/mock_MessageID.go b/pkg/mocks/streaming/util/mock_message/mock_MessageID.go index fca86396a7..92222732d9 100644 --- a/pkg/mocks/streaming/util/mock_message/mock_MessageID.go +++ b/pkg/mocks/streaming/util/mock_message/mock_MessageID.go @@ -187,6 +187,47 @@ func (_c *MockMessageID_Marshal_Call) RunAndReturn(run func() string) *MockMessa return _c } +// String provides a mock function with given fields: +func (_m *MockMessageID) String() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// MockMessageID_String_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'String' +type MockMessageID_String_Call struct { + *mock.Call +} + +// String is a helper method to define mock.On call +func (_e *MockMessageID_Expecter) String() *MockMessageID_String_Call { + return &MockMessageID_String_Call{Call: _e.mock.On("String")} +} + +func (_c *MockMessageID_String_Call) Run(run func()) *MockMessageID_String_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockMessageID_String_Call) Return(_a0 string) *MockMessageID_String_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMessageID_String_Call) RunAndReturn(run func() string) *MockMessageID_String_Call { + _c.Call.Return(run) + return _c +} + // WALName provides a mock function with 
given fields: func (_m *MockMessageID) WALName() string { ret := _m.Called() diff --git a/pkg/streaming/proto/messages.proto b/pkg/streaming/proto/messages.proto index 7bd2201152..896435dfc1 100644 --- a/pkg/streaming/proto/messages.proto +++ b/pkg/streaming/proto/messages.proto @@ -33,6 +33,7 @@ enum MessageType { DropCollection = 6; CreatePartition = 7; DropPartition = 8; + ManualFlush = 9; // begin transaction message is only used for transaction, once a begin // transaction message is received, all messages combined with the // transaction message cannot be consumed until a CommitTxn message @@ -71,11 +72,14 @@ enum MessageType { // FlushMessageBody is the body of flush message. message FlushMessageBody { - int64 collection_id = - 1; // indicate which the collection that segment belong to. + // indicate which the collection that segment belong to. + int64 collection_id = 1; repeated int64 segment_id = 2; // indicate which segment to flush. } +// ManualFlushMessageBody is the body of manual flush message. +message ManualFlushMessageBody {} + // BeginTxnMessageBody is the body of begin transaction message. // Just do nothing now. message BeginTxnMessageBody {} @@ -134,6 +138,11 @@ message DeleteMessageHeader { // FlushMessageHeader just nothing. message FlushMessageHeader {} +message ManualFlushMessageHeader { + int64 collection_id = 1; + uint64 flush_ts = 2; +} + // CreateCollectionMessageHeader is the header of create collection message. message CreateCollectionMessageHeader { int64 collection_id = 1; @@ -179,6 +188,17 @@ message RollbackTxnMessageHeader {} // Just do nothing now. message TxnMessageHeader {} +/// +/// Message Extra Response +/// Used to add extra information when response to the client. +/// +/// + +// ManualFlushExtraResponse is the extra response of manual flush message. +message ManualFlushExtraResponse { + repeated int64 segment_ids = 1; +} + // TxnContext is the context of transaction. // It will be carried by every message in a transaction. message TxnContext { diff --git a/pkg/streaming/proto/streaming.proto b/pkg/streaming/proto/streaming.proto index 7930d5a3c3..7b623718a6 100644 --- a/pkg/streaming/proto/streaming.proto +++ b/pkg/streaming/proto/streaming.proto @@ -198,6 +198,7 @@ enum StreamingCode { STREAMING_CODE_INVAILD_ARGUMENT = 8; // invalid argument STREAMING_CODE_TRANSACTION_EXPIRED = 9; // transaction expired STREAMING_CODE_INVALID_TRANSACTION_STATE = 10; // invalid transaction state + STREAMING_CODE_UNRECOVERABLE = 11; // unrecoverable error STREAMING_CODE_UNKNOWN = 999; // unknown error } diff --git a/pkg/streaming/util/message/adaptor/handler.go b/pkg/streaming/util/message/adaptor/handler.go index efcd93569b..d7dc1c97d0 100644 --- a/pkg/streaming/util/message/adaptor/handler.go +++ b/pkg/streaming/util/message/adaptor/handler.go @@ -69,7 +69,7 @@ func (m *BaseMsgPackAdaptorHandler) GenerateMsgPack(msg message.ImmutableMessage } } m.Pendings = append(m.Pendings, msg) - case message.VersionV1: + case message.VersionV1, message.VersionV2: if len(m.Pendings) != 0 { // all previous message should be vOld. m.addMsgPackIntoPending(m.Pendings...) 
m.Pendings = nil diff --git a/pkg/streaming/util/message/adaptor/handler_test.go b/pkg/streaming/util/message/adaptor/handler_test.go new file mode 100644 index 0000000000..84194d274f --- /dev/null +++ b/pkg/streaming/util/message/adaptor/handler_test.go @@ -0,0 +1,145 @@ +package adaptor + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/rmq" +) + +func TestMsgPackAdaptorHandler(t *testing.T) { + id := rmq.NewRmqID(1) + + h := NewMsgPackAdaptorHandler() + insertMsg := message.CreateTestInsertMessage(t, 1, 100, 10, id) + insertImmutableMessage := insertMsg.IntoImmutableMessage(id) + ch := make(chan *msgstream.MsgPack, 1) + go func() { + for msgPack := range h.Chan() { + ch <- msgPack + } + close(ch) + }() + h.Handle(insertImmutableMessage) + msgPack := <-ch + + assert.Equal(t, uint64(10), msgPack.BeginTs) + assert.Equal(t, uint64(10), msgPack.EndTs) + for _, tsMsg := range msgPack.Msgs { + assert.Equal(t, uint64(10), tsMsg.BeginTs()) + assert.Equal(t, uint64(10), tsMsg.EndTs()) + for _, ts := range tsMsg.(*msgstream.InsertMsg).Timestamps { + assert.Equal(t, uint64(10), ts) + } + } + + deleteMsg, err := message.NewDeleteMessageBuilderV1(). + WithVChannel("vchan1"). + WithHeader(&message.DeleteMessageHeader{ + CollectionId: 1, + }). + WithBody(&msgpb.DeleteRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Delete, + }, + CollectionID: 1, + PartitionID: 1, + Timestamps: []uint64{10}, + }). + BuildMutable() + assert.NoError(t, err) + + deleteImmutableMsg := deleteMsg. + WithTimeTick(11). + WithLastConfirmedUseMessageID(). + IntoImmutableMessage(id) + + h.Handle(deleteImmutableMsg) + msgPack = <-ch + assert.Equal(t, uint64(11), msgPack.BeginTs) + assert.Equal(t, uint64(11), msgPack.EndTs) + for _, tsMsg := range msgPack.Msgs { + assert.Equal(t, uint64(11), tsMsg.BeginTs()) + assert.Equal(t, uint64(11), tsMsg.EndTs()) + for _, ts := range tsMsg.(*msgstream.DeleteMsg).Timestamps { + assert.Equal(t, uint64(11), ts) + } + } + + // Create a txn message + msg, err := message.NewBeginTxnMessageBuilderV2(). + WithVChannel("vchan1"). + WithHeader(&message.BeginTxnMessageHeader{ + KeepaliveMilliseconds: 1000, + }). + WithBody(&message.BeginTxnMessageBody{}). + BuildMutable() + assert.NoError(t, err) + assert.NotNil(t, msg) + + txnCtx := message.TxnContext{ + TxnID: 1, + Keepalive: time.Second, + } + + beginImmutableMsg, err := message.AsImmutableBeginTxnMessageV2(msg.WithTimeTick(9). + WithTxnContext(txnCtx). + WithLastConfirmedUseMessageID(). + IntoImmutableMessage(rmq.NewRmqID(2))) + assert.NoError(t, err) + + msg, err = message.NewCommitTxnMessageBuilderV2(). + WithVChannel("vchan1"). + WithHeader(&message.CommitTxnMessageHeader{}). + WithBody(&message.CommitTxnMessageBody{}). + BuildMutable() + assert.NoError(t, err) + + commitImmutableMsg, err := message.AsImmutableCommitTxnMessageV2(msg.WithTimeTick(12). + WithTxnContext(txnCtx). + WithTxnContext(message.TxnContext{}). + WithLastConfirmedUseMessageID(). + IntoImmutableMessage(rmq.NewRmqID(3))) + assert.NoError(t, err) + + txn, err := message.NewImmutableTxnMessageBuilder(beginImmutableMsg). + Add(insertMsg.WithTxnContext(txnCtx).IntoImmutableMessage(id)). + Add(deleteMsg.WithTxnContext(txnCtx).IntoImmutableMessage(id)). 
+ Build(commitImmutableMsg) + assert.NoError(t, err) + + h.Handle(txn) + msgPack = <-ch + + assert.Equal(t, uint64(12), msgPack.BeginTs) + assert.Equal(t, uint64(12), msgPack.EndTs) + + // Create flush message + msg, err = message.NewFlushMessageBuilderV2(). + WithVChannel("vchan1"). + WithHeader(&message.FlushMessageHeader{}). + WithBody(&message.FlushMessageBody{}). + BuildMutable() + assert.NoError(t, err) + + flushMsg := msg. + WithTimeTick(13). + WithLastConfirmedUseMessageID(). + IntoImmutableMessage(rmq.NewRmqID(4)) + + h.Handle(flushMsg) + + msgPack = <-ch + + assert.Equal(t, uint64(13), msgPack.BeginTs) + assert.Equal(t, uint64(13), msgPack.EndTs) + + h.Close() + <-ch +} diff --git a/pkg/streaming/util/message/adaptor/message.go b/pkg/streaming/util/message/adaptor/message.go index 9dc9eb4f2d..f712364f10 100644 --- a/pkg/streaming/util/message/adaptor/message.go +++ b/pkg/streaming/util/message/adaptor/message.go @@ -3,6 +3,7 @@ package adaptor import ( "github.com/cockroachdb/errors" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/streaming/util/message" @@ -22,18 +23,18 @@ func NewMsgPackFromMessage(msgs ...message.ImmutableMessage) (*msgstream.MsgPack var finalErr error for _, msg := range msgs { - var tsMsg msgstream.TsMsg - var err error - switch msg.Version() { - case message.VersionOld: - tsMsg, err = fromMessageToTsMsgVOld(msg) - case message.VersionV1: - tsMsg, err = fromMessageToTsMsgV1(msg) - case message.VersionV2: - tsMsg, err = fromMessageToTsMsgV2(msg) - default: - panic("unsupported message version") + // Parse a transaction message into multiple tsMsgs. + if msg.MessageType() == message.MessageTypeTxn { + tsMsgs, err := parseTxnMsg(msg) + if err != nil { + finalErr = errors.CombineErrors(finalErr, errors.Wrapf(err, "Failed to convert txn message to msgpack, %v", msg.MessageID())) + continue + } + allTsMsgs = append(allTsMsgs, tsMsgs...) + continue } + + tsMsg, err := parseSingleMsg(msg) if err != nil { finalErr = errors.CombineErrors(finalErr, errors.Wrapf(err, "Failed to convert message to msgpack, %v", msg.MessageID())) continue @@ -49,15 +50,64 @@ func NewMsgPackFromMessage(msgs ...message.ImmutableMessage) (*msgstream.MsgPack // 1. So use the first tsMsgs's Position can read all messages which timetick is greater or equal than the first tsMsgs's BeginTs. // In other words, from the StartPositions, you can read the full msgPack. // 2. Use the last tsMsgs's Position as the EndPosition, you can read all messages following the msgPack. + beginTs := allTsMsgs[0].BeginTs() + endTs := allTsMsgs[len(allTsMsgs)-1].EndTs() + startPosition := allTsMsgs[0].Position() + endPosition := allTsMsgs[len(allTsMsgs)-1].Position() + // filter the TimeTick message. 
+ tsMsgs := make([]msgstream.TsMsg, 0, len(allTsMsgs)) + for _, msg := range allTsMsgs { + if msg.Type() == commonpb.MsgType_TimeTick { + continue + } + tsMsgs = append(tsMsgs, msg) + } return &msgstream.MsgPack{ - BeginTs: allTsMsgs[0].BeginTs(), - EndTs: allTsMsgs[len(allTsMsgs)-1].EndTs(), - Msgs: allTsMsgs, - StartPositions: []*msgstream.MsgPosition{allTsMsgs[0].Position()}, - EndPositions: []*msgstream.MsgPosition{allTsMsgs[len(allTsMsgs)-1].Position()}, + BeginTs: beginTs, + EndTs: endTs, + Msgs: tsMsgs, + StartPositions: []*msgstream.MsgPosition{startPosition}, + EndPositions: []*msgstream.MsgPosition{endPosition}, }, finalErr } +// parseTxnMsg converts a txn message to ts message list. +func parseTxnMsg(msg message.ImmutableMessage) ([]msgstream.TsMsg, error) { + txnMsg := message.AsImmutableTxnMessage(msg) + if txnMsg == nil { + panic("unreachable code, message must be a txn message") + } + + tsMsgs := make([]msgstream.TsMsg, 0, txnMsg.Size()) + err := txnMsg.RangeOver(func(im message.ImmutableMessage) error { + var tsMsg msgstream.TsMsg + tsMsg, err := parseSingleMsg(im) + if err != nil { + return err + } + tsMsgs = append(tsMsgs, tsMsg) + return nil + }) + if err != nil { + return nil, err + } + return tsMsgs, nil +} + +// parseSingleMsg converts message to ts message. +func parseSingleMsg(msg message.ImmutableMessage) (msgstream.TsMsg, error) { + switch msg.Version() { + case message.VersionOld: + return fromMessageToTsMsgVOld(msg) + case message.VersionV1: + return fromMessageToTsMsgV1(msg) + case message.VersionV2: + return fromMessageToTsMsgV2(msg) + default: + panic("unsupported message version") + } +} + func fromMessageToTsMsgVOld(msg message.ImmutableMessage) (msgstream.TsMsg, error) { panic("Not implemented") } @@ -87,6 +137,8 @@ func fromMessageToTsMsgV2(msg message.ImmutableMessage) (msgstream.TsMsg, error) switch msg.MessageType() { case message.MessageTypeFlush: tsMsg, err = NewFlushMessageBody(msg) + case message.MessageTypeManualFlush: + tsMsg, err = NewManualFlushMessageBody(msg) default: panic("unsupported message type") } @@ -115,6 +167,12 @@ func recoverMessageFromHeader(tsMsg msgstream.TsMsg, msg message.ImmutableMessag // insertMsg has multiple partition and segment assignment is done by insert message header. // so recover insert message from header before send it. return recoverInsertMsgFromHeader(tsMsg.(*msgstream.InsertMsg), insertMessage.Header(), msg.TimeTick()) + case message.MessageTypeDelete: + deleteMessage, err := message.AsImmutableDeleteMessageV1(msg) + if err != nil { + return nil, errors.Wrap(err, "Failed to convert message to delete message") + } + return recoverDeleteMsgFromHeader(tsMsg.(*msgstream.DeleteMsg), deleteMessage.Header(), msg.TimeTick()) default: return tsMsg, nil } @@ -145,5 +203,18 @@ func recoverInsertMsgFromHeader(insertMsg *msgstream.InsertMsg, header *message. 
timestamps[i] = timetick } insertMsg.Timestamps = timestamps + insertMsg.Base.Timestamp = timetick return insertMsg, nil } + +func recoverDeleteMsgFromHeader(deleteMsg *msgstream.DeleteMsg, header *message.DeleteMessageHeader, timetick uint64) (msgstream.TsMsg, error) { + if deleteMsg.GetCollectionID() != header.GetCollectionId() { + panic("unreachable code, collection id is not equal") + } + timestamps := make([]uint64, len(deleteMsg.Timestamps)) + for i := 0; i < len(timestamps); i++ { + timestamps[i] = timetick + } + deleteMsg.Timestamps = timestamps + return deleteMsg, nil +} diff --git a/pkg/streaming/util/message/adaptor/message_id_test.go b/pkg/streaming/util/message/adaptor/message_id_test.go index e5216e064b..6b0944e8ce 100644 --- a/pkg/streaming/util/message/adaptor/message_id_test.go +++ b/pkg/streaming/util/message/adaptor/message_id_test.go @@ -10,7 +10,7 @@ import ( "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/rmq" ) -func TestIDCoversion(t *testing.T) { +func TestIDConvension(t *testing.T) { id := MustGetMessageIDFromMQWrapperID(MustGetMQWrapperIDFromMessage(rmq.NewRmqID(1))) assert.True(t, id.EQ(rmq.NewRmqID(1))) diff --git a/pkg/streaming/util/message/adaptor/message_type.go b/pkg/streaming/util/message/adaptor/message_type.go index 54d9f245ac..ea3ab24389 100644 --- a/pkg/streaming/util/message/adaptor/message_type.go +++ b/pkg/streaming/util/message/adaptor/message_type.go @@ -9,7 +9,8 @@ var messageTypeToCommonpbMsgType = map[message.MessageType]commonpb.MsgType{ message.MessageTypeTimeTick: commonpb.MsgType_TimeTick, message.MessageTypeInsert: commonpb.MsgType_Insert, message.MessageTypeDelete: commonpb.MsgType_Delete, - message.MessageTypeFlush: commonpb.MsgType_Flush, + message.MessageTypeFlush: commonpb.MsgType_FlushSegment, + message.MessageTypeManualFlush: commonpb.MsgType_ManualFlush, message.MessageTypeCreateCollection: commonpb.MsgType_CreateCollection, message.MessageTypeDropCollection: commonpb.MsgType_DropCollection, message.MessageTypeCreatePartition: commonpb.MsgType_CreatePartition, diff --git a/pkg/streaming/util/message/adaptor/ts_msg_newer.go b/pkg/streaming/util/message/adaptor/ts_msg_newer.go index 52d1bda3cf..0fb558deb4 100644 --- a/pkg/streaming/util/message/adaptor/ts_msg_newer.go +++ b/pkg/streaming/util/message/adaptor/ts_msg_newer.go @@ -52,7 +52,7 @@ func (t *tsMsgImpl) SetTs(ts uint64) { type FlushMessageBody struct { *tsMsgImpl - *message.FlushMessageBody + FlushMessage message.ImmutableFlushMessageV2 } func NewFlushMessageBody(msg message.ImmutableMessage) (msgstream.TsMsg, error) { @@ -60,10 +60,6 @@ func NewFlushMessageBody(msg message.ImmutableMessage) (msgstream.TsMsg, error) if err != nil { return nil, err } - body, err := flushMsg.Body() - if err != nil { - return nil, err - } return &FlushMessageBody{ tsMsgImpl: &tsMsgImpl{ BaseMsg: msgstream.BaseMsg{ @@ -72,8 +68,32 @@ func NewFlushMessageBody(msg message.ImmutableMessage) (msgstream.TsMsg, error) }, ts: msg.TimeTick(), sz: msg.EstimateSize(), - msgType: commonpb.MsgType(msg.MessageType()), + msgType: MustGetCommonpbMsgTypeFromMessageType(msg.MessageType()), }, - FlushMessageBody: body, + FlushMessage: flushMsg, + }, nil +} + +type ManualFlushMessageBody struct { + *tsMsgImpl + ManualFlushMessage message.ImmutableManualFlushMessageV2 +} + +func NewManualFlushMessageBody(msg message.ImmutableMessage) (msgstream.TsMsg, error) { + flushMsg, err := message.AsImmutableManualFlushMessageV2(msg) + if err != nil { + return nil, err + } + return &ManualFlushMessageBody{ + 
tsMsgImpl: &tsMsgImpl{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: msg.TimeTick(), + EndTimestamp: msg.TimeTick(), + }, + ts: msg.TimeTick(), + sz: msg.EstimateSize(), + msgType: MustGetCommonpbMsgTypeFromMessageType(msg.MessageType()), + }, + ManualFlushMessage: flushMsg, }, nil } diff --git a/pkg/streaming/util/message/builder.go b/pkg/streaming/util/message/builder.go index 0eacc04555..0432cbb613 100644 --- a/pkg/streaming/util/message/builder.go +++ b/pkg/streaming/util/message/builder.go @@ -45,6 +45,7 @@ var ( NewCreatePartitionMessageBuilderV1 = createNewMessageBuilderV1[*CreatePartitionMessageHeader, *msgpb.CreatePartitionRequest]() NewDropPartitionMessageBuilderV1 = createNewMessageBuilderV1[*DropPartitionMessageHeader, *msgpb.DropPartitionRequest]() NewFlushMessageBuilderV2 = createNewMessageBuilderV2[*FlushMessageHeader, *FlushMessageBody]() + NewManualFlushMessageBuilderV2 = createNewMessageBuilderV2[*ManualFlushMessageHeader, *ManualFlushMessageBody]() NewBeginTxnMessageBuilderV2 = createNewMessageBuilderV2[*BeginTxnMessageHeader, *BeginTxnMessageBody]() NewCommitTxnMessageBuilderV2 = createNewMessageBuilderV2[*CommitTxnMessageHeader, *CommitTxnMessageBody]() NewRollbackTxnMessageBuilderV2 = createNewMessageBuilderV2[*RollbackTxnMessageHeader, *RollbackTxnMessageBody]() diff --git a/pkg/streaming/util/message/message_builder_test.go b/pkg/streaming/util/message/message_builder_test.go index 5c2a503392..cb798fe891 100644 --- a/pkg/streaming/util/message/message_builder_test.go +++ b/pkg/streaming/util/message/message_builder_test.go @@ -2,14 +2,12 @@ package message_test import ( "bytes" - "fmt" "testing" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" - "github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_message" "github.com/milvus-io/milvus/pkg/streaming/util/message" "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/walimplstest" ) @@ -45,28 +43,18 @@ func TestMessage(t *testing.T) { assert.Equal(t, uint64(123), mutableMessage.TimeTick()) assert.Equal(t, uint64(456), mutableMessage.BarrierTimeTick()) - lcMsgID := mock_message.NewMockMessageID(t) - lcMsgID.EXPECT().Marshal().Return("lcMsgID") + lcMsgID := walimplstest.NewTestMessageID(1) mutableMessage.WithLastConfirmed(lcMsgID) v, ok = mutableMessage.Properties().Get("_lc") assert.True(t, ok) - assert.Equal(t, v, "lcMsgID") + assert.Equal(t, v, "1") v, ok = mutableMessage.Properties().Get("_vc") assert.True(t, ok) assert.Equal(t, "v1", v) assert.Equal(t, "v1", mutableMessage.VChannel()) - msgID := mock_message.NewMockMessageID(t) - msgID.EXPECT().EQ(msgID).Return(true) - msgID.EXPECT().WALName().Return("testMsgID") - message.RegisterMessageIDUnmsarshaler("testMsgID", func(data string) (message.MessageID, error) { - if data == "lcMsgID" { - return msgID, nil - } - panic(fmt.Sprintf("unexpected data: %s", data)) - }) - + msgID := walimplstest.NewTestMessageID(1) immutableMessage := message.NewImmutableMesasge(msgID, []byte("payload"), map[string]string{ @@ -74,7 +62,7 @@ func TestMessage(t *testing.T) { "_t": "1", "_tt": message.EncodeUint64(456), "_v": "1", - "_lc": "lcMsgID", + "_lc": "1", }) assert.True(t, immutableMessage.MessageID().EQ(msgID)) @@ -84,7 +72,7 @@ func TestMessage(t *testing.T) { assert.Equal(t, "value", v) assert.True(t, ok) assert.Equal(t, message.MessageTypeTimeTick, immutableMessage.MessageType()) - assert.Equal(t, 36, immutableMessage.EstimateSize()) + assert.Equal(t, 30, immutableMessage.EstimateSize()) 
assert.Equal(t, message.Version(1), immutableMessage.Version()) assert.Equal(t, uint64(456), immutableMessage.TimeTick()) assert.NotNil(t, immutableMessage.LastConfirmedMessageID()) diff --git a/pkg/streaming/util/message/message_id.go b/pkg/streaming/util/message/message_id.go index b1d9fa14f8..f6864506e9 100644 --- a/pkg/streaming/util/message/message_id.go +++ b/pkg/streaming/util/message/message_id.go @@ -49,4 +49,7 @@ type MessageID interface { // Marshal marshal the message id. Marshal() string + + // Convert into string for logging. + String() string } diff --git a/pkg/streaming/util/message/message_type.go b/pkg/streaming/util/message/message_type.go index 3f102b9447..a2a2d3369b 100644 --- a/pkg/streaming/util/message/message_type.go +++ b/pkg/streaming/util/message/message_type.go @@ -14,6 +14,7 @@ const ( MessageTypeInsert MessageType = MessageType(messagespb.MessageType_Insert) MessageTypeDelete MessageType = MessageType(messagespb.MessageType_Delete) MessageTypeFlush MessageType = MessageType(messagespb.MessageType_Flush) + MessageTypeManualFlush MessageType = MessageType(messagespb.MessageType_ManualFlush) MessageTypeCreateCollection MessageType = MessageType(messagespb.MessageType_CreateCollection) MessageTypeDropCollection MessageType = MessageType(messagespb.MessageType_DropCollection) MessageTypeCreatePartition MessageType = MessageType(messagespb.MessageType_CreatePartition) @@ -30,6 +31,7 @@ var messageTypeName = map[MessageType]string{ MessageTypeInsert: "INSERT", MessageTypeDelete: "DELETE", MessageTypeFlush: "FLUSH", + MessageTypeManualFlush: "MANUAL_FLUSH", MessageTypeCreateCollection: "CREATE_COLLECTION", MessageTypeDropCollection: "DROP_COLLECTION", MessageTypeCreatePartition: "CREATE_PARTITION", diff --git a/pkg/streaming/util/message/specialized_message.go b/pkg/streaming/util/message/specialized_message.go index 050ec53c38..9ee1892ee2 100644 --- a/pkg/streaming/util/message/specialized_message.go +++ b/pkg/streaming/util/message/specialized_message.go @@ -22,6 +22,7 @@ type ( CreatePartitionMessageHeader = messagespb.CreatePartitionMessageHeader DropPartitionMessageHeader = messagespb.DropPartitionMessageHeader FlushMessageHeader = messagespb.FlushMessageHeader + ManualFlushMessageHeader = messagespb.ManualFlushMessageHeader BeginTxnMessageHeader = messagespb.BeginTxnMessageHeader CommitTxnMessageHeader = messagespb.CommitTxnMessageHeader RollbackTxnMessageHeader = messagespb.RollbackTxnMessageHeader @@ -30,12 +31,17 @@ type ( type ( FlushMessageBody = messagespb.FlushMessageBody + ManualFlushMessageBody = messagespb.ManualFlushMessageBody BeginTxnMessageBody = messagespb.BeginTxnMessageBody CommitTxnMessageBody = messagespb.CommitTxnMessageBody RollbackTxnMessageBody = messagespb.RollbackTxnMessageBody TxnMessageBody = messagespb.TxnMessageBody ) +type ( + ManualFlushExtraResponse = messagespb.ManualFlushExtraResponse +) + // messageTypeMap maps the proto message type to the message type. 
var messageTypeMap = map[reflect.Type]MessageType{ reflect.TypeOf(&TimeTickMessageHeader{}): MessageTypeTimeTick, @@ -46,6 +52,7 @@ var messageTypeMap = map[reflect.Type]MessageType{ reflect.TypeOf(&CreatePartitionMessageHeader{}): MessageTypeCreatePartition, reflect.TypeOf(&DropPartitionMessageHeader{}): MessageTypeDropPartition, reflect.TypeOf(&FlushMessageHeader{}): MessageTypeFlush, + reflect.TypeOf(&ManualFlushMessageHeader{}): MessageTypeManualFlush, reflect.TypeOf(&BeginTxnMessageHeader{}): MessageTypeBeginTxn, reflect.TypeOf(&CommitTxnMessageHeader{}): MessageTypeCommitTxn, reflect.TypeOf(&RollbackTxnMessageHeader{}): MessageTypeRollbackTxn, @@ -83,6 +90,7 @@ type ( ImmutableCreatePartitionMessageV1 = specializedImmutableMessage[*CreatePartitionMessageHeader, *msgpb.CreatePartitionRequest] ImmutableDropPartitionMessageV1 = specializedImmutableMessage[*DropPartitionMessageHeader, *msgpb.DropPartitionRequest] ImmutableFlushMessageV2 = specializedImmutableMessage[*FlushMessageHeader, *FlushMessageBody] + ImmutableManualFlushMessageV2 = specializedImmutableMessage[*ManualFlushMessageHeader, *ManualFlushMessageBody] ImmutableBeginTxnMessageV2 = specializedImmutableMessage[*BeginTxnMessageHeader, *BeginTxnMessageBody] ImmutableCommitTxnMessageV2 = specializedImmutableMessage[*CommitTxnMessageHeader, *CommitTxnMessageBody] ImmutableRollbackTxnMessageV2 = specializedImmutableMessage[*RollbackTxnMessageHeader, *RollbackTxnMessageBody] @@ -98,6 +106,7 @@ var ( AsMutableCreatePartitionMessageV1 = asSpecializedMutableMessage[*CreatePartitionMessageHeader, *msgpb.CreatePartitionRequest] AsMutableDropPartitionMessageV1 = asSpecializedMutableMessage[*DropPartitionMessageHeader, *msgpb.DropPartitionRequest] AsMutableFlushMessageV2 = asSpecializedMutableMessage[*FlushMessageHeader, *FlushMessageBody] + AsMutableManualFlushMessageV2 = asSpecializedMutableMessage[*ManualFlushMessageHeader, *ManualFlushMessageBody] AsMutableBeginTxnMessageV2 = asSpecializedMutableMessage[*BeginTxnMessageHeader, *BeginTxnMessageBody] AsMutableCommitTxnMessageV2 = asSpecializedMutableMessage[*CommitTxnMessageHeader, *CommitTxnMessageBody] AsMutableRollbackTxnMessageV2 = asSpecializedMutableMessage[*RollbackTxnMessageHeader, *RollbackTxnMessageBody] @@ -110,6 +119,7 @@ var ( AsImmutableCreatePartitionMessageV1 = asSpecializedImmutableMessage[*CreatePartitionMessageHeader, *msgpb.CreatePartitionRequest] AsImmutableDropPartitionMessageV1 = asSpecializedImmutableMessage[*DropPartitionMessageHeader, *msgpb.DropPartitionRequest] AsImmutableFlushMessageV2 = asSpecializedImmutableMessage[*FlushMessageHeader, *FlushMessageBody] + AsImmutableManualFlushMessageV2 = asSpecializedImmutableMessage[*ManualFlushMessageHeader, *ManualFlushMessageBody] AsImmutableBeginTxnMessageV2 = asSpecializedImmutableMessage[*BeginTxnMessageHeader, *BeginTxnMessageBody] AsImmutableCommitTxnMessageV2 = asSpecializedImmutableMessage[*CommitTxnMessageHeader, *CommitTxnMessageBody] AsImmutableRollbackTxnMessageV2 = asSpecializedImmutableMessage[*RollbackTxnMessageHeader, *RollbackTxnMessageBody] diff --git a/pkg/streaming/walimpls/impls/pulsar/message_id.go b/pkg/streaming/walimpls/impls/pulsar/message_id.go index e2f3fa3bcd..a6e6313359 100644 --- a/pkg/streaming/walimpls/impls/pulsar/message_id.go +++ b/pkg/streaming/walimpls/impls/pulsar/message_id.go @@ -2,6 +2,7 @@ package pulsar import ( "encoding/base64" + "fmt" "github.com/apache/pulsar-client-go/pulsar" "github.com/cockroachdb/errors" @@ -84,3 +85,7 @@ func (id pulsarID) EQ(other 
message.MessageID) bool { func (id pulsarID) Marshal() string { return base64.StdEncoding.EncodeToString(id.Serialize()) } + +func (id pulsarID) String() string { + return fmt.Sprintf("%d/%d/%d", id.LedgerID(), id.EntryID(), id.BatchIdx()) +} diff --git a/pkg/streaming/walimpls/impls/rmq/message_id.go b/pkg/streaming/walimpls/impls/rmq/message_id.go index 6312cc0b3f..af548ad07d 100644 --- a/pkg/streaming/walimpls/impls/rmq/message_id.go +++ b/pkg/streaming/walimpls/impls/rmq/message_id.go @@ -1,6 +1,8 @@ package rmq import ( + "strconv" + "github.com/cockroachdb/errors" "github.com/milvus-io/milvus/pkg/streaming/util/message" @@ -66,3 +68,7 @@ func (id rmqID) EQ(other message.MessageID) bool { func (id rmqID) Marshal() string { return message.EncodeInt64(int64(id)) } + +func (id rmqID) String() string { + return strconv.FormatInt(int64(id), 10) +} diff --git a/pkg/streaming/walimpls/impls/walimplstest/message_id.go b/pkg/streaming/walimpls/impls/walimplstest/message_id.go index afc8eb7ca0..16fd80768a 100644 --- a/pkg/streaming/walimpls/impls/walimplstest/message_id.go +++ b/pkg/streaming/walimpls/impls/walimplstest/message_id.go @@ -61,3 +61,7 @@ func (id testMessageID) EQ(other message.MessageID) bool { func (id testMessageID) Marshal() string { return strconv.FormatInt(int64(id), 10) } + +func (id testMessageID) String() string { + return strconv.FormatInt(int64(id), 10) +} diff --git a/tests/integration/channel_balance/channel_balance_test.go b/tests/integration/channel_balance/channel_balance_test.go index edb86ffd03..c734466748 100644 --- a/tests/integration/channel_balance/channel_balance_test.go +++ b/tests/integration/channel_balance/channel_balance_test.go @@ -14,12 +14,16 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/tests/integration" ) func TestChannelBalanceSuite(t *testing.T) { + if streamingutil.IsStreamingServiceEnabled() { + t.Skip("skip channel balance test in streaming service mode") + } suite.Run(t, new(ChannelBalanceSuite)) } diff --git a/tests/integration/compaction/clustering_compaction_test.go b/tests/integration/compaction/clustering_compaction_test.go index c8c65eaac7..24791ffe83 100644 --- a/tests/integration/compaction/clustering_compaction_test.go +++ b/tests/integration/compaction/clustering_compaction_test.go @@ -62,6 +62,9 @@ func (s *ClusteringCompactionSuite) TestClusteringCompaction() { paramtable.Get().Save(paramtable.Get().DataCoordCfg.SegmentMaxSize.Key, strconv.Itoa(1)) defer paramtable.Get().Reset(paramtable.Get().DataCoordCfg.SegmentMaxSize.Key) + paramtable.Get().Save(paramtable.Get().PulsarCfg.MaxMessageSize.Key, strconv.Itoa(500*1024)) + defer paramtable.Get().Reset(paramtable.Get().PulsarCfg.MaxMessageSize.Key) + paramtable.Get().Save(paramtable.Get().DataNodeCfg.ClusteringCompactionWorkerPoolSize.Key, strconv.Itoa(8)) defer paramtable.Get().Reset(paramtable.Get().DataNodeCfg.ClusteringCompactionWorkerPoolSize.Key) diff --git a/tests/integration/minicluster_v2.go b/tests/integration/minicluster_v2.go index 5c87bf09fe..d2bd5e5c7c 100644 --- a/tests/integration/minicluster_v2.go +++ b/tests/integration/minicluster_v2.go @@ -44,11 +44,14 @@ import ( grpcquerynodeclient "github.com/milvus-io/milvus/internal/distributed/querynode/client" grpcrootcoord 
"github.com/milvus-io/milvus/internal/distributed/rootcoord" grpcrootcoordclient "github.com/milvus-io/milvus/internal/distributed/rootcoord/client" + "github.com/milvus-io/milvus/internal/distributed/streaming" + "github.com/milvus-io/milvus/internal/distributed/streamingnode" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv" "github.com/milvus-io/milvus/internal/util/hookutil" + "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -57,18 +60,6 @@ import ( var params *paramtable.ComponentParam = paramtable.Get() -type ClusterConfig struct { - // ProxyNum int - // todo coord num can be more than 1 if enable Active-Standby - // RootCoordNum int - // DataCoordNum int - // IndexCoordNum int - // QueryCoordNum int - QueryNodeNum int - DataNodeNum int - IndexNodeNum int -} - func DefaultParams() map[string]string { testPath := fmt.Sprintf("integration-test-%d", time.Now().Unix()) return map[string]string{ @@ -83,21 +74,12 @@ func DefaultParams() map[string]string { } } -func DefaultClusterConfig() ClusterConfig { - return ClusterConfig{ - QueryNodeNum: 1, - DataNodeNum: 1, - IndexNodeNum: 1, - } -} - type MiniClusterV2 struct { ctx context.Context mu sync.RWMutex - params map[string]string - clusterConfig ClusterConfig + params map[string]string factory dependency.Factory ChunkManager storage.ChunkManager @@ -118,16 +100,18 @@ type MiniClusterV2 struct { QueryNodeClient types.QueryNodeClient IndexNodeClient types.IndexNodeClient - DataNode *grpcdatanode.Server - QueryNode *grpcquerynode.Server - IndexNode *grpcindexnode.Server + DataNode *grpcdatanode.Server + StreamingNode *streamingnode.Server + QueryNode *grpcquerynode.Server + IndexNode *grpcindexnode.Server - MetaWatcher MetaWatcher - ptmu sync.Mutex - querynodes []*grpcquerynode.Server - qnid atomic.Int64 - datanodes []*grpcdatanode.Server - dnid atomic.Int64 + MetaWatcher MetaWatcher + ptmu sync.Mutex + querynodes []*grpcquerynode.Server + qnid atomic.Int64 + datanodes []*grpcdatanode.Server + dnid atomic.Int64 + streamingnodes []*streamingnode.Server Extension *ReportChanExtension } @@ -144,7 +128,6 @@ func StartMiniClusterV2(ctx context.Context, opts ...OptionV2) (*MiniClusterV2, cluster.Extension = InitReportExtension() cluster.params = DefaultParams() - cluster.clusterConfig = DefaultClusterConfig() for _, opt := range opts { opt(cluster) } @@ -166,6 +149,10 @@ func StartMiniClusterV2(ctx context.Context, opts ...OptionV2) (*MiniClusterV2, } cluster.EtcdCli = etcdCli + if streamingutil.IsStreamingServiceEnabled() { + streaming.Init() + } + cluster.MetaWatcher = &EtcdMetaWatcher{ rootPath: etcdConfig.RootPath.GetValue(), etcdCli: cluster.EtcdCli, @@ -240,6 +227,12 @@ func StartMiniClusterV2(ctx context.Context, opts ...OptionV2) (*MiniClusterV2, if err != nil { return nil, err } + if streamingutil.IsStreamingServiceEnabled() { + cluster.StreamingNode, err = streamingnode.NewServer(cluster.factory) + if err != nil { + return nil, err + } + } cluster.QueryNode, err = grpcquerynode.NewServer(ctx, cluster.factory) if err != nil { return nil, err @@ -315,6 +308,22 @@ func (cluster *MiniClusterV2) AddDataNode() *grpcdatanode.Server { return node } +func (cluster *MiniClusterV2) AddStreamingNode() { + cluster.ptmu.Lock() + defer 
cluster.ptmu.Unlock() + + node, err := streamingnode.NewServer(cluster.factory) + if err != nil { + panic(err) + } + err = node.Run() + if err != nil { + panic(err) + } + + cluster.streamingnodes = append(cluster.streamingnodes, node) +} + func (cluster *MiniClusterV2) Start() error { log.Info("mini cluster start") err := cluster.RootCoord.Run() @@ -363,6 +372,14 @@ func (cluster *MiniClusterV2) Start() error { if !healthy { return errors.New("minicluster is not healthy after 120s") } + + if streamingutil.IsStreamingServiceEnabled() { + err = cluster.StreamingNode.Run() + if err != nil { + return err + } + } + log.Info("minicluster started") return nil } @@ -379,7 +396,13 @@ func (cluster *MiniClusterV2) Stop() error { log.Info("mini cluster proxy stopped") cluster.StopAllDataNodes() + cluster.StopAllStreamingNodes() cluster.StopAllQueryNodes() + + if streamingutil.IsStreamingServiceEnabled() { + streaming.Release() + } + cluster.IndexNode.Stop() log.Info("mini cluster indexNode stopped") @@ -429,6 +452,18 @@ func (cluster *MiniClusterV2) StopAllDataNodes() { log.Info(fmt.Sprintf("mini cluster stopped %d extra datanode", numExtraDN)) } +func (cluster *MiniClusterV2) StopAllStreamingNodes() { + if cluster.StreamingNode != nil { + cluster.StreamingNode.Stop() + log.Info("mini cluster main streamingnode stopped") + } + for _, node := range cluster.streamingnodes { + node.Stop() + } + log.Info(fmt.Sprintf("mini cluster stopped %d streaming nodes", len(cluster.streamingnodes))) + cluster.streamingnodes = nil +} + func (cluster *MiniClusterV2) GetContext() context.Context { return cluster.ctx } diff --git a/tests/integration/streaming/hello_streaming_test.go b/tests/integration/streaming/hello_streaming_test.go new file mode 100644 index 0000000000..721e7abde0 --- /dev/null +++ b/tests/integration/streaming/hello_streaming_test.go @@ -0,0 +1,202 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package streaming + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/suite" + "go.uber.org/zap" + "google.golang.org/protobuf/proto" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/util/streamingutil" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/metric" + "github.com/milvus-io/milvus/tests/integration" +) + +type HelloStreamingSuite struct { + integration.MiniClusterSuite +} + +func (s *HelloStreamingSuite) SetupSuite() { + streamingutil.SetStreamingServiceEnabled() + s.MiniClusterSuite.SetupSuite() +} + +func (s *HelloStreamingSuite) TeardownSuite() { + s.MiniClusterSuite.TearDownSuite() + streamingutil.UnsetStreamingServiceEnabled() +} + +func (s *HelloStreamingSuite) TestHelloStreaming() { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*10) + defer cancel() + c := s.Cluster + + const ( + dim = 128 + dbName = "" + rowNum = 100000 + + indexType = integration.IndexFaissIvfFlat + metricType = metric.L2 + vecType = schemapb.DataType_FloatVector + ) + + collectionName := "TestHelloStreaming_" + funcutil.GenRandomStr() + + schema := integration.ConstructSchemaOfVecDataType(collectionName, dim, false, vecType) + marshaledSchema, err := proto.Marshal(schema) + s.NoError(err) + + // create collection + createCollectionStatus, err := c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: common.DefaultShardsNum, + ConsistencyLevel: commonpb.ConsistencyLevel_Strong, + }) + err = merr.CheckRPCCall(createCollectionStatus, err) + s.NoError(err) + log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus)) + + // show collection + showCollectionsResp, err := c.Proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{}) + err = merr.CheckRPCCall(showCollectionsResp, err) + s.NoError(err) + log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp)) + + // insert + pkColumn := integration.NewInt64FieldData(integration.Int64Field, rowNum) + fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim) + hashKeys := integration.GenerateHashKeys(rowNum) + insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{ + DbName: dbName, + CollectionName: collectionName, + FieldsData: []*schemapb.FieldData{pkColumn, fVecColumn}, + HashKeys: hashKeys, + NumRows: uint32(rowNum), + }) + err = merr.CheckRPCCall(insertResult, err) + s.NoError(err) + s.Equal(int64(rowNum), insertResult.GetInsertCnt()) + + // delete + deleteResult, err := c.Proxy.Delete(ctx, &milvuspb.DeleteRequest{ + DbName: dbName, + CollectionName: collectionName, + Expr: integration.Int64Field + " in [1, 2]", + }) + err = merr.CheckRPCCall(deleteResult, err) + s.NoError(err) + + // flush + flushResp, err := c.Proxy.Flush(ctx, &milvuspb.FlushRequest{ + DbName: dbName, + CollectionNames: []string{collectionName}, + }) + err = merr.CheckRPCCall(flushResp, err) + s.NoError(err) + s.T().Logf("flush response, flushTs=%d, segmentIDs=%v", flushResp.GetCollFlushTs()[collectionName], flushResp.GetCollSegIDs()[collectionName]) + segmentIDs, has := 
flushResp.GetCollSegIDs()[collectionName] + ids := segmentIDs.GetData() + s.Require().NotEmpty(segmentIDs) + s.Require().True(has) + flushTs, has := flushResp.GetCollFlushTs()[collectionName] + s.True(has) + s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName) + + // create index + createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{ + CollectionName: collectionName, + FieldName: integration.FloatVecField, + IndexName: "_default", + ExtraParams: integration.ConstructIndexParam(dim, indexType, metricType), + }) + err = merr.CheckRPCCall(createIndexStatus, err) + s.NoError(err) + s.WaitForIndexBuilt(ctx, collectionName, integration.FloatVecField) + + segments, err := c.MetaWatcher.ShowSegments() + s.NoError(err) + s.NotEmpty(segments) + s.Equal(2, len(segments)) + s.Equal(int64(rowNum), segments[0].GetNumOfRows()) + + // load + loadStatus, err := c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + }) + err = merr.CheckRPCCall(loadStatus, err) + s.NoError(err) + s.WaitForLoad(ctx, collectionName) + + // search + expr := fmt.Sprintf("%s > 0", integration.Int64Field) + nq := 10 + topk := 10 + roundDecimal := -1 + params := integration.GetSearchParams(indexType, metricType) + searchReq := integration.ConstructSearchRequest("", collectionName, expr, + integration.FloatVecField, vecType, nil, metricType, params, nq, dim, topk, roundDecimal) + + searchResult, err := c.Proxy.Search(ctx, searchReq) + err = merr.CheckRPCCall(searchResult, err) + s.NoError(err) + s.Equal(nq*topk, len(searchResult.GetResults().GetScores())) + + // query + queryResult, err := c.Proxy.Query(ctx, &milvuspb.QueryRequest{ + DbName: dbName, + CollectionName: collectionName, + Expr: "", + OutputFields: []string{"count(*)"}, + }) + err = merr.CheckRPCCall(queryResult, err) + s.NoError(err) + s.Equal(int64(rowNum-2), queryResult.GetFieldsData()[0].GetScalars().GetLongData().GetData()[0]) + + // release collection + status, err := c.Proxy.ReleaseCollection(ctx, &milvuspb.ReleaseCollectionRequest{ + CollectionName: collectionName, + }) + err = merr.CheckRPCCall(status, err) + s.NoError(err) + + // drop collection + status, err = c.Proxy.DropCollection(ctx, &milvuspb.DropCollectionRequest{ + CollectionName: collectionName, + }) + err = merr.CheckRPCCall(status, err) + s.NoError(err) +} + +func TestHelloStreamingNode(t *testing.T) { + suite.Run(t, new(HelloStreamingSuite)) +}
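
Note: for reference, a minimal sketch of how a consumer reads a vchannel through the streaming WAL client when the streaming service is enabled (streamingutil.IsStreamingServiceEnabled()), using only the APIs exercised by the streamPipeline.ConsumeMsgStream change in this patch. The package and function names of the sketch itself are illustrative, not part of the patch.

	package walreadexample

	import (
		"context"

		"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
		"github.com/milvus-io/milvus/internal/distributed/streaming"
		"github.com/milvus-io/milvus/pkg/mq/msgstream"
		"github.com/milvus-io/milvus/pkg/streaming/util/message"
		"github.com/milvus-io/milvus/pkg/streaming/util/message/adaptor"
		"github.com/milvus-io/milvus/pkg/streaming/util/options"
	)

	// readDMLFromWAL seeks the WAL to the given checkpoint position and returns a
	// scanner plus a channel of MsgPacks carrying only insert and delete messages
	// whose time tick is >= the checkpoint timestamp. The caller owns the scanner
	// and must Close() it when the pipeline shuts down.
	func readDMLFromWAL(position *msgpb.MsgPosition) (streaming.Scanner, <-chan *msgstream.MsgPack) {
		// The patch decodes the legacy MQ message ID with a hardcoded "pulsar" WAL name.
		startFrom := adaptor.MustGetMessageIDFromMQWrapperIDBytes("pulsar", position.GetMsgID())
		handler := adaptor.NewMsgPackAdaptorHandler()
		scanner := streaming.WAL().Read(context.Background(), streaming.ReadOption{
			VChannel:      position.GetChannelName(),
			DeliverPolicy: options.DeliverPolicyStartFrom(startFrom),
			DeliverFilters: []options.DeliverFilter{
				// Only deliver messages with time tick >= the position timestamp.
				options.DeliverFilterTimeTickGTE(position.GetTimestamp()),
				// Only deliver insert and delete messages.
				options.DeliverFilterMessageType(message.MessageTypeInsert, message.MessageTypeDelete),
			},
			MessageHandler: handler,
		})
		return scanner, handler.Chan()
	}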
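Note: TxnManager.GracefulClose now takes a context and returns ctx.Err() when transaction sessions are still pending at the deadline. A caller-side sketch follows; the txnCloser interface and the helper are hypothetical, added only to keep the example self-contained, while the concrete *TxnManager changed in this patch satisfies the interface.

	package txncloseexample

	import (
		"context"
		"time"

		"go.uber.org/zap"

		"github.com/milvus-io/milvus/pkg/log"
	)

	// txnCloser is a hypothetical stand-in for the TxnManager reworked in this patch.
	type txnCloser interface {
		GracefulClose(ctx context.Context) error
	}

	// shutdownTxnManager waits up to timeout for pending transaction sessions to be
	// consumed; on timeout it logs the error and lets the caller force-stop.
	func shutdownTxnManager(m txnCloser, timeout time.Duration) {
		ctx, cancel := context.WithTimeout(context.Background(), timeout)
		defer cancel()
		if err := m.GracefulClose(ctx); err != nil {
			// GracefulClose returns ctx.Err() if sessions remain at the deadline.
			log.Warn("txn manager did not close gracefully, forcing shutdown", zap.Error(err))
		}
	}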
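Note: the patch also introduces a ManualFlush message (proto enum value, header/body messages, and NewManualFlushMessageBuilderV2). Below is a hedged sketch of building one, following the builder chain the adaptor handler tests use for Flush messages. The Go field names CollectionId/FlushTs are inferred from the proto definition, and the rmq message ID plus explicit time-tick stamping are test-style placeholders, not the production append path.

	package manualflushexample

	import (
		"github.com/milvus-io/milvus/pkg/streaming/util/message"
		"github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/rmq"
	)

	// buildManualFlushMessage builds a ManualFlush control message for a vchannel
	// and converts it into an immutable message the same way the handler tests
	// prepare messages for MsgPackAdaptorHandler.
	func buildManualFlushMessage(vchannel string, collectionID int64, flushTs uint64) (message.ImmutableMessage, error) {
		msg, err := message.NewManualFlushMessageBuilderV2().
			WithVChannel(vchannel).
			WithHeader(&message.ManualFlushMessageHeader{
				CollectionId: collectionID,
				FlushTs:      flushTs,
			}).
			WithBody(&message.ManualFlushMessageBody{}).
			BuildMutable()
		if err != nil {
			return nil, err
		}
		// Stamp a time tick and a message ID explicitly, mirroring the handler tests.
		return msg.
			WithTimeTick(flushTs).
			WithLastConfirmedUseMessageID().
			IntoImmutableMessage(rmq.NewRmqID(1)), nil
	}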