diff --git a/configs/milvus.yaml b/configs/milvus.yaml index 32161d849f..f6d489e921 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -404,6 +404,15 @@ common: ttl: 20 # ttl value when session granting a lease to register service retryTimes: 30 # retry times when session sending etcd requests + # preCreatedTopic decides whether to use existing (pre-created) topics + preCreatedTopic: + enabled: false + # support pre-created topics + # the names of the pre-created topics + names: ["topic1", "topic2"] + # a separate topic must be set to carry the currently consumed timestamp of each channel + timeticker: "timetick-channel" + # QuotaConfig, configurations of Milvus quota and limits. # By default, we enable: # 1. TT protection; diff --git a/deployments/docker/dev/docker-compose.yml b/deployments/docker/dev/docker-compose.yml index 4498437f05..640bd16426 100644 --- a/deployments/docker/dev/docker-compose.yml +++ b/deployments/docker/dev/docker-compose.yml @@ -49,7 +49,13 @@ services: - ${DOCKER_VOLUME_DIRECTORY:-.}/volumes/minio:/minio_data command: minio server /minio_data --console-address ":9001" healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"] + test: + [ + "CMD", + "curl", + "-f", + "http://localhost:9000/minio/health/live" + ] interval: 30s timeout: 20s retries: 3 diff --git a/internal/datacoord/server.go b/internal/datacoord/server.go index 9e734ee11f..8e0970257a 100644 --- a/internal/datacoord/server.go +++ b/internal/datacoord/server.go @@ -570,11 +570,16 @@ func (s *Server) startDataNodeTtLoop(ctx context.Context) { log.Error("DataCoord failed to create timetick channel", zap.Error(err)) panic(err) } + + timeTickChannel := Params.CommonCfg.DataCoordTimeTick.GetValue() + if Params.CommonCfg.PreCreatedTopicEnabled.GetAsBool() { + timeTickChannel = Params.CommonCfg.TimeTicker.GetValue() + } subName := fmt.Sprintf("%s-%d-datanodeTl", Params.CommonCfg.DataCoordSubName.GetValue(), paramtable.GetNodeID()) - 
ttMsgStream.AsConsumer([]string{Params.CommonCfg.DataCoordTimeTick.GetValue()}, - subName, mqwrapper.SubscriptionPositionLatest) + + ttMsgStream.AsConsumer([]string{timeTickChannel}, subName, mqwrapper.SubscriptionPositionLatest) log.Info("DataCoord creates the timetick channel consumer", - zap.String("timeTickChannel", Params.CommonCfg.DataCoordTimeTick.GetValue()), + zap.String("timeTickChannel", timeTickChannel), zap.String("subscription", subName)) go s.handleDataNodeTimetickMsgstream(ctx, ttMsgStream) diff --git a/internal/datanode/flow_graph_dmstream_input_node.go b/internal/datanode/flow_graph_dmstream_input_node.go index a80ba2c77d..75149c5f05 100644 --- a/internal/datanode/flow_graph_dmstream_input_node.go +++ b/internal/datanode/flow_graph_dmstream_input_node.go @@ -36,8 +36,8 @@ import ( // DmInputNode receives messages from message streams, packs messages between two timeticks, and passes all // -// messages between two timeticks to the following flowgraph node. In DataNode, the following flow graph node is -// flowgraph ddNode. +// messages between two timeticks to the following flowgraph node. In DataNode, the following flow graph node is +// flowgraph ddNode. 
func newDmInputNode(dispatcherClient msgdispatcher.Client, seekPos *msgpb.MsgPosition, dmNodeConfig *nodeConfig) (*flowgraph.InputNode, error) { log := log.With(zap.Int64("nodeID", paramtable.GetNodeID()), zap.Int64("collection ID", dmNodeConfig.collectionID), diff --git a/internal/datanode/flow_graph_dmstream_input_node_test.go b/internal/datanode/flow_graph_dmstream_input_node_test.go index 129be01f3f..85490c6832 100644 --- a/internal/datanode/flow_graph_dmstream_input_node_test.go +++ b/internal/datanode/flow_graph_dmstream_input_node_test.go @@ -100,6 +100,10 @@ func (mtm *mockTtMsgStream) GetLatestMsgID(channel string) (msgstream.MessageID, return nil, nil } +func (mtm *mockTtMsgStream) CheckTopicValid(channel string) error { + return nil +} + func TestNewDmInputNode(t *testing.T) { client := msgdispatcher.NewClient(&mockMsgStreamFactory{}, typeutil.DataNodeRole, paramtable.GetNodeID()) _, err := newDmInputNode(client, new(msgpb.MsgPosition), &nodeConfig{ diff --git a/internal/mq/mqimpl/rocksmq/client/consumer.go b/internal/mq/mqimpl/rocksmq/client/consumer.go index 0765d8fd7a..eaf769a783 100644 --- a/internal/mq/mqimpl/rocksmq/client/consumer.go +++ b/internal/mq/mqimpl/rocksmq/client/consumer.go @@ -72,4 +72,7 @@ type Consumer interface { // GetLatestMsgID get the latest msgID GetLatestMsgID() (int64, error) + + // CheckTopicValid checks whether the pre-created topic is valid + CheckTopicValid(topic string) error } diff --git a/internal/mq/mqimpl/rocksmq/client/consumer_impl.go b/internal/mq/mqimpl/rocksmq/client/consumer_impl.go index dea745d8b9..49a7375e8a 100644 --- a/internal/mq/mqimpl/rocksmq/client/consumer_impl.go +++ b/internal/mq/mqimpl/rocksmq/client/consumer_impl.go @@ -140,6 +140,10 @@ func (c *consumer) GetLatestMsgID() (int64, error) { if err != nil { return msgID, err } - return msgID, nil } + +func (c *consumer) CheckTopicValid(topic string) error { + err := c.client.server.CheckTopicValid(topic) + return err +} diff --git 
a/internal/mq/mqimpl/rocksmq/server/rocksmq.go b/internal/mq/mqimpl/rocksmq/server/rocksmq.go index 65e580fb3f..dee82f3147 100644 --- a/internal/mq/mqimpl/rocksmq/server/rocksmq.go +++ b/internal/mq/mqimpl/rocksmq/server/rocksmq.go @@ -42,6 +42,7 @@ type RocksMQ interface { RegisterConsumer(consumer *Consumer) error GetLatestMsg(topicName string) (int64, error) + CheckTopicValid(topicName string) error Produce(topicName string, messages []ProducerMessage) ([]UniqueID, error) Consume(topicName string, groupName string, n int) ([]ConsumerMessage, error) diff --git a/internal/mq/mqimpl/rocksmq/server/rocksmq_impl.go b/internal/mq/mqimpl/rocksmq/server/rocksmq_impl.go index dd119e629e..f390d98151 100644 --- a/internal/mq/mqimpl/rocksmq/server/rocksmq_impl.go +++ b/internal/mq/mqimpl/rocksmq/server/rocksmq_impl.go @@ -34,6 +34,7 @@ import ( "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper" "github.com/milvus-io/milvus/internal/util/hardware" + "github.com/milvus-io/milvus/internal/util/merr" "github.com/milvus-io/milvus/internal/util/paramtable" "github.com/milvus-io/milvus/internal/util/retry" "github.com/milvus-io/milvus/internal/util/typeutil" @@ -1074,3 +1075,24 @@ func (rmq *rocksmq) updateAckedInfo(topicName, groupName string, firstID UniqueI } return nil } + +func (rmq *rocksmq) CheckTopicValid(topic string) error { + // Check if key exists + log := log.With(zap.String("topic", topic)) + + _, ok := topicMu.Load(topic) + if !ok { + return merr.WrapErrTopicNotFound(topic, "failed to get topic") + } + + latestMsgID, err := rmq.GetLatestMsg(topic) + if err != nil { + return err + } + + if latestMsgID != DefaultMessageID { + return merr.WrapErrTopicNotEmpty(topic, "topic is not empty") + } + log.Info("created topic is empty") + return nil +} diff --git a/internal/mq/mqimpl/rocksmq/server/rocksmq_impl_test.go b/internal/mq/mqimpl/rocksmq/server/rocksmq_impl_test.go index e45864bba2..4abac56143 100644 --- 
a/internal/mq/mqimpl/rocksmq/server/rocksmq_impl_test.go +++ b/internal/mq/mqimpl/rocksmq/server/rocksmq_impl_test.go @@ -30,6 +30,7 @@ import ( rocksdbkv "github.com/milvus-io/milvus/internal/kv/rocksdb" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/util/etcd" + "github.com/milvus-io/milvus/internal/util/merr" "github.com/milvus-io/milvus/internal/util/paramtable" "github.com/tecbot/gorocksdb" "go.uber.org/zap" @@ -995,6 +996,55 @@ func TestRocksmq_GetLatestMsg(t *testing.T) { assert.NotNil(t, err) } +func TestRocksmq_CheckPreTopicValid(t *testing.T) { + suffix := "_topic" + kvPath := rmqPath + kvPathSuffix + suffix + defer os.RemoveAll(kvPath) + idAllocator := InitIDAllocator(kvPath) + + rocksdbPath := rmqPath + suffix + defer os.RemoveAll(rocksdbPath + kvSuffix) + defer os.RemoveAll(rocksdbPath) + paramtable.Init() + rmq, err := NewRocksMQ(rocksdbPath, idAllocator) + assert.Nil(t, err) + defer rmq.Close() + + channelName1 := "topic1" + // topic not exist + err = rmq.CheckTopicValid(channelName1) + assert.Equal(t, true, errors.Is(err, merr.ErrTopicNotFound)) + + channelName2 := "topic2" + // topic is not empty + err = rmq.CreateTopic(channelName2) + defer rmq.DestroyTopic(channelName2) + assert.Nil(t, err) + topicMu.Store(channelName2, new(sync.Mutex)) + + pMsgs := make([]ProducerMessage, 10) + for i := 0; i < 10; i++ { + msg := "message_" + strconv.Itoa(i) + pMsg := ProducerMessage{Payload: []byte(msg)} + pMsgs[i] = pMsg + } + _, err = rmq.Produce(channelName2, pMsgs) + assert.NoError(t, err) + + err = rmq.CheckTopicValid(channelName2) + assert.Equal(t, true, errors.Is(err, merr.ErrTopicNotEmpty)) + + channelName3 := "topic3" + // pass + err = rmq.CreateTopic(channelName3) + defer rmq.DestroyTopic(channelName3) + assert.Nil(t, err) + + topicMu.Store(channelName3, new(sync.Mutex)) + err = rmq.CheckTopicValid(channelName3) + assert.NoError(t, err) +} + func TestRocksmq_Close(t *testing.T) { ep := etcdEndpoints() etcdCli, err := 
etcd.GetRemoteEtcdClient(ep) diff --git a/internal/mq/msgstream/mock_msgstream.go b/internal/mq/msgstream/mock_msgstream.go index e5eaeb9b78..53e9e97abf 100644 --- a/internal/mq/msgstream/mock_msgstream.go +++ b/internal/mq/msgstream/mock_msgstream.go @@ -1,12 +1,11 @@ -// Code generated by mockery v2.15.0. DO NOT EDIT. +// Code generated by mockery v2.23.1. DO NOT EDIT. package msgstream import ( - "github.com/milvus-io/milvus-proto/go-api/msgpb" - mock "github.com/stretchr/testify/mock" - + msgpb "github.com/milvus-io/milvus-proto/go-api/msgpb" mqwrapper "github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper" + mock "github.com/stretchr/testify/mock" ) // MockMsgStream is an autogenerated mock type for the MsgStream type @@ -52,6 +51,11 @@ func (_c *MockMsgStream_AsConsumer_Call) Return() *MockMsgStream_AsConsumer_Call return _c } +func (_c *MockMsgStream_AsConsumer_Call) RunAndReturn(run func([]string, string, mqwrapper.SubscriptionInitialPosition)) *MockMsgStream_AsConsumer_Call { + _c.Call.Return(run) + return _c +} + // AsProducer provides a mock function with given fields: channels func (_m *MockMsgStream) AsProducer(channels []string) { _m.Called(channels) @@ -80,11 +84,20 @@ func (_c *MockMsgStream_AsProducer_Call) Return() *MockMsgStream_AsProducer_Call return _c } +func (_c *MockMsgStream_AsProducer_Call) RunAndReturn(run func([]string)) *MockMsgStream_AsProducer_Call { + _c.Call.Return(run) + return _c +} + // Broadcast provides a mock function with given fields: _a0 func (_m *MockMsgStream) Broadcast(_a0 *MsgPack) (map[string][]mqwrapper.MessageID, error) { ret := _m.Called(_a0) var r0 map[string][]mqwrapper.MessageID + var r1 error + if rf, ok := ret.Get(0).(func(*MsgPack) (map[string][]mqwrapper.MessageID, error)); ok { + return rf(_a0) + } if rf, ok := ret.Get(0).(func(*MsgPack) map[string][]mqwrapper.MessageID); ok { r0 = rf(_a0) } else { @@ -93,7 +106,6 @@ func (_m *MockMsgStream) Broadcast(_a0 *MsgPack) (map[string][]mqwrapper.Message } } - 
var r1 error if rf, ok := ret.Get(1).(func(*MsgPack) error); ok { r1 = rf(_a0) } else { @@ -126,6 +138,11 @@ func (_c *MockMsgStream_Broadcast_Call) Return(_a0 map[string][]mqwrapper.Messag return _c } +func (_c *MockMsgStream_Broadcast_Call) RunAndReturn(run func(*MsgPack) (map[string][]mqwrapper.MessageID, error)) *MockMsgStream_Broadcast_Call { + _c.Call.Return(run) + return _c +} + // Chan provides a mock function with given fields: func (_m *MockMsgStream) Chan() <-chan *MsgPack { ret := _m.Called() @@ -164,6 +181,53 @@ func (_c *MockMsgStream_Chan_Call) Return(_a0 <-chan *MsgPack) *MockMsgStream_Ch return _c } +func (_c *MockMsgStream_Chan_Call) RunAndReturn(run func() <-chan *MsgPack) *MockMsgStream_Chan_Call { + _c.Call.Return(run) + return _c +} + +// CheckTopicValid provides a mock function with given fields: channel +func (_m *MockMsgStream) CheckTopicValid(channel string) error { + ret := _m.Called(channel) + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(channel) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockMsgStream_CheckTopicValid_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckTopicValid' +type MockMsgStream_CheckTopicValid_Call struct { + *mock.Call +} + +// CheckTopicValid is a helper method to define mock.On call +// - channel string +func (_e *MockMsgStream_Expecter) CheckTopicValid(channel interface{}) *MockMsgStream_CheckTopicValid_Call { + return &MockMsgStream_CheckTopicValid_Call{Call: _e.mock.On("CheckTopicValid", channel)} +} + +func (_c *MockMsgStream_CheckTopicValid_Call) Run(run func(channel string)) *MockMsgStream_CheckTopicValid_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockMsgStream_CheckTopicValid_Call) Return(_a0 error) *MockMsgStream_CheckTopicValid_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMsgStream_CheckTopicValid_Call) RunAndReturn(run func(string) 
error) *MockMsgStream_CheckTopicValid_Call { + _c.Call.Return(run) + return _c +} + // Close provides a mock function with given fields: func (_m *MockMsgStream) Close() { _m.Called() @@ -191,11 +255,20 @@ func (_c *MockMsgStream_Close_Call) Return() *MockMsgStream_Close_Call { return _c } +func (_c *MockMsgStream_Close_Call) RunAndReturn(run func()) *MockMsgStream_Close_Call { + _c.Call.Return(run) + return _c +} + // GetLatestMsgID provides a mock function with given fields: channel func (_m *MockMsgStream) GetLatestMsgID(channel string) (mqwrapper.MessageID, error) { ret := _m.Called(channel) var r0 mqwrapper.MessageID + var r1 error + if rf, ok := ret.Get(0).(func(string) (mqwrapper.MessageID, error)); ok { + return rf(channel) + } if rf, ok := ret.Get(0).(func(string) mqwrapper.MessageID); ok { r0 = rf(channel) } else { @@ -204,7 +277,6 @@ func (_m *MockMsgStream) GetLatestMsgID(channel string) (mqwrapper.MessageID, er } } - var r1 error if rf, ok := ret.Get(1).(func(string) error); ok { r1 = rf(channel) } else { @@ -237,6 +309,11 @@ func (_c *MockMsgStream_GetLatestMsgID_Call) Return(_a0 mqwrapper.MessageID, _a1 return _c } +func (_c *MockMsgStream_GetLatestMsgID_Call) RunAndReturn(run func(string) (mqwrapper.MessageID, error)) *MockMsgStream_GetLatestMsgID_Call { + _c.Call.Return(run) + return _c +} + // GetProduceChannels provides a mock function with given fields: func (_m *MockMsgStream) GetProduceChannels() []string { ret := _m.Called() @@ -275,6 +352,11 @@ func (_c *MockMsgStream_GetProduceChannels_Call) Return(_a0 []string) *MockMsgSt return _c } +func (_c *MockMsgStream_GetProduceChannels_Call) RunAndReturn(run func() []string) *MockMsgStream_GetProduceChannels_Call { + _c.Call.Return(run) + return _c +} + // Produce provides a mock function with given fields: _a0 func (_m *MockMsgStream) Produce(_a0 *MsgPack) error { ret := _m.Called(_a0) @@ -312,6 +394,11 @@ func (_c *MockMsgStream_Produce_Call) Return(_a0 error) *MockMsgStream_Produce_C return _c } 
+func (_c *MockMsgStream_Produce_Call) RunAndReturn(run func(*MsgPack) error) *MockMsgStream_Produce_Call { + _c.Call.Return(run) + return _c +} + // Seek provides a mock function with given fields: offset func (_m *MockMsgStream) Seek(offset []*msgpb.MsgPosition) error { ret := _m.Called(offset) @@ -349,6 +436,11 @@ func (_c *MockMsgStream_Seek_Call) Return(_a0 error) *MockMsgStream_Seek_Call { return _c } +func (_c *MockMsgStream_Seek_Call) RunAndReturn(run func([]*msgpb.MsgPosition) error) *MockMsgStream_Seek_Call { + _c.Call.Return(run) + return _c +} + // SetRepackFunc provides a mock function with given fields: repackFunc func (_m *MockMsgStream) SetRepackFunc(repackFunc RepackFunc) { _m.Called(repackFunc) @@ -377,6 +469,11 @@ func (_c *MockMsgStream_SetRepackFunc_Call) Return() *MockMsgStream_SetRepackFun return _c } +func (_c *MockMsgStream_SetRepackFunc_Call) RunAndReturn(run func(RepackFunc)) *MockMsgStream_SetRepackFunc_Call { + _c.Call.Return(run) + return _c +} + type mockConstructorTestingTNewMockMsgStream interface { mock.TestingT Cleanup(func()) diff --git a/internal/mq/msgstream/mq_msgstream.go b/internal/mq/msgstream/mq_msgstream.go index 451a4fe276..ca61bedac8 100644 --- a/internal/mq/msgstream/mq_msgstream.go +++ b/internal/mq/msgstream/mq_msgstream.go @@ -100,6 +100,7 @@ func (ms *mqMsgStream) AsProducer(channels []string) { log.Error("MsgStream asProducer's channel is an empty string") break } + fn := func() error { pp, err := ms.client.CreateProducer(mqwrapper.ProducerOptions{Topic: channel, EnableCompression: true}) if err != nil { @@ -132,6 +133,14 @@ func (ms *mqMsgStream) GetLatestMsgID(channel string) (MessageID, error) { return lastMsg, nil } +func (ms *mqMsgStream) CheckTopicValid(channel string) error { + err := ms.consumers[channel].CheckTopicValid(channel) + if err != nil { + return err + } + return nil +} + // AsConsumerWithPosition Create consumer to receive message from channels, with initial position // if initial position is 
set to latest, last message in the channel is exclusive func (ms *mqMsgStream) AsConsumer(channels []string, subName string, position mqwrapper.SubscriptionInitialPosition) { diff --git a/internal/mq/msgstream/mqwrapper/consumer.go b/internal/mq/msgstream/mqwrapper/consumer.go index a2bd8a4087..4046f8d753 100644 --- a/internal/mq/msgstream/mqwrapper/consumer.go +++ b/internal/mq/msgstream/mqwrapper/consumer.go @@ -70,4 +70,7 @@ type Consumer interface { // GetLatestMsgID return the latest message ID GetLatestMsgID() (MessageID, error) + + // CheckTopicValid checks whether the pre-created topic is valid + CheckTopicValid(channel string) error } diff --git a/internal/mq/msgstream/mqwrapper/kafka/kafka_consumer.go b/internal/mq/msgstream/mqwrapper/kafka/kafka_consumer.go index dcc575472e..5d654676c7 100644 --- a/internal/mq/msgstream/mqwrapper/kafka/kafka_consumer.go +++ b/internal/mq/msgstream/mqwrapper/kafka/kafka_consumer.go @@ -5,10 +5,10 @@ import ( "time" "github.com/cockroachdb/errors" - "github.com/confluentinc/confluent-kafka-go/kafka" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper" + "github.com/milvus-io/milvus/internal/util/merr" "go.uber.org/zap" ) @@ -220,6 +220,30 @@ func (kc *Consumer) GetLatestMsgID() (mqwrapper.MessageID, error) { return &kafkaID{messageID: high}, nil } +func (kc *Consumer) CheckTopicValid(topic string) error { + latestMsgID, err := kc.GetLatestMsgID() + log.With(zap.String("topic", kc.topic)) + // check that the topic exists; NOTE(review): a kafka.Error with any other code leaves the switch without returning and reaches the emptiness check below - confirm latestMsgID is non-nil in that case + if err != nil { + switch v := err.(type) { + case kafka.Error: + if v.Code() == kafka.ErrUnknownTopic || v.Code() == kafka.ErrUnknownPartition || v.Code() == kafka.ErrUnknownTopicOrPart { + return merr.WrapErrTopicNotFound(topic, "topic get latest msg ID failed, topic or partition does not exists") + } + default: + return err + } + } + + // check topic is empty + if !latestMsgID.AtEarliestPosition() { + return merr.WrapErrTopicNotEmpty(topic, "topic is not empty") + } + 
log.Info("created topic is empty") + + return nil +} + func (kc *Consumer) Close() { kc.closeOnce.Do(func() { close(kc.closeCh) diff --git a/internal/mq/msgstream/mqwrapper/kafka/kafka_consumer_test.go b/internal/mq/msgstream/mqwrapper/kafka/kafka_consumer_test.go index 11877fa0b4..4ec6c3d907 100644 --- a/internal/mq/msgstream/mqwrapper/kafka/kafka_consumer_test.go +++ b/internal/mq/msgstream/mqwrapper/kafka/kafka_consumer_test.go @@ -254,3 +254,17 @@ func createConfig(groupID string) *kafka.ConfigMap { "api.version.request": "true", } } + +func TestKafkaConsumer_CheckPreTopicValid(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + groupID := fmt.Sprintf("test-groupid-%d", rand.Int()) + topic := fmt.Sprintf("test-topicName-%d", rand.Int()) + + config := createConfig(groupID) + consumer, err := newKafkaConsumer(config, topic, groupID, mqwrapper.SubscriptionPositionEarliest) + assert.NoError(t, err) + defer consumer.Close() + + err = consumer.CheckTopicValid(topic) + assert.NoError(t, err) +} diff --git a/internal/mq/msgstream/mqwrapper/pulsar/pulsar_consumer.go b/internal/mq/msgstream/mqwrapper/pulsar/pulsar_consumer.go index eb3edf4250..578f3def8f 100644 --- a/internal/mq/msgstream/mqwrapper/pulsar/pulsar_consumer.go +++ b/internal/mq/msgstream/mqwrapper/pulsar/pulsar_consumer.go @@ -24,6 +24,7 @@ import ( "unsafe" "github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper" + "github.com/milvus-io/milvus/internal/util/merr" "github.com/milvus-io/milvus/internal/util/retry" "github.com/apache/pulsar-client-go/pulsar" @@ -152,6 +153,20 @@ func (pc *Consumer) GetLatestMsgID() (mqwrapper.MessageID, error) { return &pulsarID{messageID: msgID}, err } +func (pc *Consumer) CheckTopicValid(topic string) error { + latestMsgID, err := pc.GetLatestMsgID() + // Pulsar creates that topic under the namespace provided in the topic name automatically + if err != nil { + return err + } + + if !latestMsgID.AtEarliestPosition() { + return merr.WrapErrTopicNotEmpty(topic, "topic is 
not empty") + } + log.Info("created topic is empty", zap.String("topic", topic)) + return nil +} + // patchEarliestMessageID unsafe patch logic to change messageID partitionIdx to 0 // ONLY used in Chan() function // DON'T use elsewhere diff --git a/internal/mq/msgstream/mqwrapper/pulsar/pulsar_consumer_test.go b/internal/mq/msgstream/mqwrapper/pulsar/pulsar_consumer_test.go index 969484f7cb..79ef129f70 100644 --- a/internal/mq/msgstream/mqwrapper/pulsar/pulsar_consumer_test.go +++ b/internal/mq/msgstream/mqwrapper/pulsar/pulsar_consumer_test.go @@ -218,3 +218,26 @@ func TestPulsarClientUnsubscribeTwice(t *testing.T) { assert.True(t, strings.Contains(err.Error(), "Consumer not found")) t.Log(err) } + +func TestCheckPreTopicValid(t *testing.T) { + pulsarAddress := getPulsarAddress() + pc, err := NewClient(DefaultPulsarTenant, DefaultPulsarNamespace, pulsar.ClientOptions{URL: pulsarAddress}) + assert.Nil(t, err) + + receiveChannel := make(chan pulsar.ConsumerMessage, 100) + consumer, err := pc.client.Subscribe(pulsar.ConsumerOptions{ + Topic: "Topic-1", + SubscriptionName: "SubName-1", + SubscriptionInitialPosition: pulsar.SubscriptionInitialPosition(mqwrapper.SubscriptionPositionEarliest), + MessageChannel: receiveChannel, + }) + assert.Nil(t, err) + assert.NotNil(t, consumer) + + str := consumer.Subscription() + assert.NotNil(t, str) + + pulsarConsumer := &Consumer{c: consumer, closeCh: make(chan struct{})} + err = pulsarConsumer.CheckTopicValid("Topic-1") + assert.NoError(t, err) +} diff --git a/internal/mq/msgstream/mqwrapper/rmq/rmq_consumer.go b/internal/mq/msgstream/mqwrapper/rmq/rmq_consumer.go index 505a60e371..4aa040f9b5 100644 --- a/internal/mq/msgstream/mqwrapper/rmq/rmq_consumer.go +++ b/internal/mq/msgstream/mqwrapper/rmq/rmq_consumer.go @@ -101,3 +101,7 @@ func (rc *Consumer) GetLatestMsgID() (mqwrapper.MessageID, error) { msgID, err := rc.c.GetLatestMsgID() return &rmqID{messageID: msgID}, err } + +func (rc *Consumer) CheckTopicValid(topic string) 
error { + return rc.c.CheckTopicValid(topic) +} diff --git a/internal/mq/msgstream/msgstream.go b/internal/mq/msgstream/msgstream.go index 2bea441f43..33be2c3e5a 100644 --- a/internal/mq/msgstream/msgstream.go +++ b/internal/mq/msgstream/msgstream.go @@ -66,6 +66,7 @@ type MsgStream interface { Seek(offset []*MsgPosition) error GetLatestMsgID(channel string) (MessageID, error) + CheckTopicValid(channel string) error } type Factory interface { diff --git a/internal/proxy/mock_test.go b/internal/proxy/mock_test.go index bf85bedf45..13c3828b8c 100644 --- a/internal/proxy/mock_test.go +++ b/internal/proxy/mock_test.go @@ -296,6 +296,10 @@ func (ms *simpleMockMsgStream) GetLatestMsgID(channel string) (msgstream.Message return nil, nil } +func (ms *simpleMockMsgStream) CheckTopicValid(topic string) error { + return nil +} + func newSimpleMockMsgStream() *simpleMockMsgStream { return &simpleMockMsgStream{ msgChan: make(chan *msgstream.MsgPack, 1024), diff --git a/internal/rootcoord/dml_channels.go b/internal/rootcoord/dml_channels.go index acc9a9c228..b5cfe56cd2 100644 --- a/internal/rootcoord/dml_channels.go +++ b/internal/rootcoord/dml_channels.go @@ -23,11 +23,13 @@ import ( "sync" "github.com/milvus-io/milvus/internal/metrics" + "github.com/milvus-io/milvus/internal/util/paramtable" "go.uber.org/zap" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/mq/msgstream" + "github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper" ) type dmlMsgStream struct { @@ -142,7 +144,25 @@ type dmlChannels struct { channelsHeap channelsHeap } -func newDmlChannels(ctx context.Context, factory msgstream.Factory, chanNamePrefix string, chanNum int64) *dmlChannels { +func newDmlChannels(ctx context.Context, factory msgstream.Factory, chanNamePrefixDefault string, chanNumDefault int64) *dmlChannels { + params := paramtable.Get().CommonCfg + var ( + chanNamePrefix string + chanNum int64 + names []string + ) + + // if topic created, use the existed topic + if 
params.PreCreatedTopicEnabled.GetAsBool() { + chanNamePrefix = "" + chanNum = int64(len(params.TopicNames.GetAsStrings())) + names = params.TopicNames.GetAsStrings() + } else { + chanNamePrefix = chanNamePrefixDefault + chanNum = chanNumDefault + names = genChannelNames(chanNamePrefix, chanNum) + } + d := &dmlChannels{ ctx: ctx, factory: factory, @@ -151,21 +171,35 @@ func newDmlChannels(ctx context.Context, factory msgstream.Factory, chanNamePref channelsHeap: make([]*dmlMsgStream, 0, chanNum), } - for i := int64(0); i < chanNum; i++ { - name := genChannelName(d.namePrefix, i) + for i, name := range names { ms, err := factory.NewMsgStream(ctx) if err != nil { - log.Error("Failed to add msgstream", zap.String("name", name), zap.Error(err)) + log.Error("Failed to add msgstream", + zap.String("name", name), + zap.Error(err)) panic("Failed to add msgstream") } - ms.AsProducer([]string{name}) + if params.PreCreatedTopicEnabled.GetAsBool() { + subName := fmt.Sprintf("pre-created-topic-check-%s", name) + ms.AsConsumer([]string{name}, subName, mqwrapper.SubscriptionPositionUnknown) + // check topic exist and check the existed topic whether empty or not + // kafka and rmq will err if the topic does not yet exist, pulsar will not + // if one of the topics is not empty, panic + err := ms.CheckTopicValid(name) + if err != nil { + log.Error("created topic is invaild", zap.String("name", name), zap.Error(err)) + panic("created topic is invaild") + } + } + + ms.AsProducer([]string{name}) dms := &dmlMsgStream{ ms: ms, refcnt: 0, used: 0, - idx: i, - pos: int(i), + idx: int64(i), + pos: i, } d.pool.Store(name, dms) d.channelsHeap = append(d.channelsHeap, dms) @@ -194,7 +228,7 @@ func (d *dmlChannels) getChannelNames(count int) []string { item := heap.Pop(&d.channelsHeap).(*dmlMsgStream) item.BookUsage() items = append(items, item) - result = append(result, genChannelName(d.namePrefix, item.idx)) + result = append(result, getChannelName(d.namePrefix, item.idx)) } for _, item := 
range items { @@ -211,7 +245,7 @@ func (d *dmlChannels) listChannels() []string { func(k, v interface{}) bool { dms := v.(*dmlMsgStream) if dms.RefCnt() > 0 { - chanNames = append(chanNames, genChannelName(d.namePrefix, dms.idx)) + chanNames = append(chanNames, getChannelName(d.namePrefix, dms.idx)) } return true }) @@ -306,6 +340,19 @@ func (d *dmlChannels) removeChannels(names ...string) { } } -func genChannelName(prefix string, idx int64) string { +func getChannelName(prefix string, idx int64) string { + params := paramtable.Get().CommonCfg + if params.PreCreatedTopicEnabled.GetAsBool() { + return params.TopicNames.GetAsStrings()[idx] + } return fmt.Sprintf("%s_%d", prefix, idx) } + +func genChannelNames(prefix string, num int64) []string { + var results []string + for idx := int64(0); idx < num; idx++ { + result := fmt.Sprintf("%s_%d", prefix, idx) + results = append(results, result) + } + return results +} diff --git a/internal/rootcoord/dml_channels_test.go b/internal/rootcoord/dml_channels_test.go index d8c4a618e6..66225aa904 100644 --- a/internal/rootcoord/dml_channels_test.go +++ b/internal/rootcoord/dml_channels_test.go @@ -29,6 +29,7 @@ import ( "github.com/milvus-io/milvus/internal/mq/msgstream/mqwrapper" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/paramtable" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -158,6 +159,13 @@ func TestDmlChannels(t *testing.T) { dml.removeChannels(chans0...) 
assert.Equal(t, 0, dml.getChannelNum()) + + paramtable.Get().Save(Params.CommonCfg.PreCreatedTopicEnabled.Key, "true") + paramtable.Get().Save(Params.CommonCfg.TopicNames.Key, "topic1,topic2") + defer paramtable.Get().Reset(Params.CommonCfg.PreCreatedTopicEnabled.Key) + defer paramtable.Get().Reset(Params.CommonCfg.TopicNames.Key) + + assert.Panics(t, func() { newDmlChannels(ctx, factory, dmlChanPrefix, totalDmlChannelNum) }) } func TestDmChannelsFailure(t *testing.T) { diff --git a/internal/util/merr/errors.go b/internal/util/merr/errors.go index 1b2e235ce8..9dd09bcc33 100644 --- a/internal/util/merr/errors.go +++ b/internal/util/merr/errors.go @@ -99,6 +99,10 @@ var ( // Metrics related ErrMetricNotFound = newMilvusError("metric not found", 1200, false) + // Topic related + ErrTopicNotFound = newMilvusError("topic not found", 1300, false) + ErrTopicNotEmpty = newMilvusError("topic not empty", 1301, false) + // Do NOT export this, // never allow programmer using this, keep only for converting unknown error to milvusError errUnexpected = newMilvusError("unexpected error", (1<<16)-1, false) diff --git a/internal/util/merr/errors_test.go b/internal/util/merr/errors_test.go index f25fe4a618..16144c296b 100644 --- a/internal/util/merr/errors_test.go +++ b/internal/util/merr/errors_test.go @@ -108,6 +108,11 @@ func (s *ErrSuite) TestWrap() { // Metrics related s.ErrorIs(WrapErrMetricNotFound("unknown", "failed to get metric"), ErrMetricNotFound) + + // Topic related + s.ErrorIs(WrapErrTopicNotFound("unknown", "failed to get topic"), ErrTopicNotFound) + s.ErrorIs(WrapErrTopicNotEmpty("unknown", "topic is not empty"), ErrTopicNotEmpty) + } func (s *ErrSuite) TestCombine() { diff --git a/internal/util/merr/utils.go b/internal/util/merr/utils.go index 468c59b4bf..512846af67 100644 --- a/internal/util/merr/utils.go +++ b/internal/util/merr/utils.go @@ -364,6 +364,23 @@ func WrapErrMetricNotFound(name string, msg ...string) error { return err } +// Topic related +func 
WrapErrTopicNotFound(name string, msg ...string) error { + err := errors.Wrapf(ErrTopicNotFound, "topic=%s", name) + if len(msg) > 0 { + err = errors.Wrap(err, strings.Join(msg, "; ")) + } + return err +} + +func WrapErrTopicNotEmpty(name string, msg ...string) error { + err := errors.Wrapf(ErrTopicNotEmpty, "topic=%s", name) + if len(msg) > 0 { + err = errors.Wrap(err, strings.Join(msg, "; ")) + } + return err +} + func wrapWithField(err error, name string, value any) error { return errors.Wrapf(err, "%s=%v", name, value) } diff --git a/internal/util/paramtable/component_param.go b/internal/util/paramtable/component_param.go index 645a5d3129..070212f69d 100644 --- a/internal/util/paramtable/component_param.go +++ b/internal/util/paramtable/component_param.go @@ -211,6 +211,10 @@ type commonConfig struct { SessionTTL ParamItem `refreshable:"false"` SessionRetryTimes ParamItem `refreshable:"false"` + + PreCreatedTopicEnabled ParamItem `refreshable:"true"` + TopicNames ParamItem `refreshable:"true"` + TimeTicker ParamItem `refreshable:"true"` } func (p *commonConfig) init(base *BaseTable) { @@ -613,6 +617,24 @@ like the old password verification when updating the credential`, } p.SessionRetryTimes.Init(base.mgr) + p.PreCreatedTopicEnabled = ParamItem{ + Key: "common.preCreatedTopic.enabled", + Version: "2.3.0", + DefaultValue: "false", + } + p.PreCreatedTopicEnabled.Init(base.mgr) + + p.TopicNames = ParamItem{ + Key: "common.preCreatedTopic.names", + Version: "2.3.0", + } + p.TopicNames.Init(base.mgr) + + p.TimeTicker = ParamItem{ + Key: "common.preCreatedTopic.timeticker", + Version: "2.3.0", + } + p.TimeTicker.Init(base.mgr) } type traceConfig struct { diff --git a/internal/util/paramtable/component_param_test.go b/internal/util/paramtable/component_param_test.go index 4d8c637262..c74b18e42d 100644 --- a/internal/util/paramtable/component_param_test.go +++ b/internal/util/paramtable/component_param_test.go @@ -122,6 +122,14 @@ func TestComponentParam(t *testing.T) 
{ params.Save("common.security.superUsers", "") assert.Equal(t, []string{""}, Params.SuperUsers.GetAsStrings()) + + assert.Equal(t, false, Params.PreCreatedTopicEnabled.GetAsBool()) + + params.Save("common.preCreatedTopic.names", "topic1,topic2,topic3") + assert.Equal(t, []string{"topic1", "topic2", "topic3"}, Params.TopicNames.GetAsStrings()) + + params.Save("common.preCreatedTopic.timeticker", "timeticker") + assert.Equal(t, []string{"timeticker"}, Params.TimeTicker.GetAsStrings()) }) t.Run("test rootCoordConfig", func(t *testing.T) {