diff --git a/internal/streamingnode/server/server.go b/internal/streamingnode/server/server.go
index 52413071da..8956d8d78e 100644
--- a/internal/streamingnode/server/server.go
+++ b/internal/streamingnode/server/server.go
@@ -11,6 +11,7 @@ import (
 	"github.com/milvus-io/milvus/internal/util/sessionutil"
 	"github.com/milvus-io/milvus/pkg/log"
 	"github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb"
+	_ "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/kafka"
 	_ "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/pulsar"
 	_ "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/rmq"
 )
diff --git a/internal/streamingnode/server/wal/adaptor/scanner_adaptor.go b/internal/streamingnode/server/wal/adaptor/scanner_adaptor.go
index 8f17d68773..6e293353c4 100644
--- a/internal/streamingnode/server/wal/adaptor/scanner_adaptor.go
+++ b/internal/streamingnode/server/wal/adaptor/scanner_adaptor.go
@@ -114,7 +114,7 @@ func (s *scannerAdaptorImpl) executeConsume() {
 		Message: s.pendingQueue.Next(),
 	})
 	if handleResult.Error != nil {
-		s.Finish(err)
+		s.Finish(handleResult.Error)
 		return
 	}
 	if handleResult.MessageHandled {
diff --git a/pkg/mq/msgstream/mqwrapper/kafka/kafka_client.go b/pkg/mq/msgstream/mqwrapper/kafka/kafka_client.go
index f2d0bddb10..3755060b7e 100644
--- a/pkg/mq/msgstream/mqwrapper/kafka/kafka_client.go
+++ b/pkg/mq/msgstream/mqwrapper/kafka/kafka_client.go
@@ -241,7 +241,7 @@ func (kc *kafkaClient) Subscribe(ctx context.Context, options mqwrapper.Consumer
 }
 
 func (kc *kafkaClient) EarliestMessageID() common.MessageID {
-	return &kafkaID{messageID: int64(kafka.OffsetBeginning)}
+	return &KafkaID{MessageID: int64(kafka.OffsetBeginning)}
 }
 
 func (kc *kafkaClient) StringToMsgID(id string) (common.MessageID, error) {
@@ -250,7 +250,7 @@ func (kc *kafkaClient) StringToMsgID(id string) (common.MessageID, error) {
 		return nil, err
 	}
 
-	return &kafkaID{messageID: offset}, nil
+	return &KafkaID{MessageID: offset}, nil
 }
 
 func (kc *kafkaClient) specialExtraConfig(current *kafka.ConfigMap, special kafka.ConfigMap) {
@@ -265,7 +265,7 @@ func (kc *kafkaClient) specialExtraConfig(current *kafka.ConfigMap, special kafk
 
 func (kc *kafkaClient) BytesToMsgID(id []byte) (common.MessageID, error) {
 	offset := DeserializeKafkaID(id)
-	return &kafkaID{messageID: offset}, nil
+	return &KafkaID{MessageID: offset}, nil
 }
 
 func (kc *kafkaClient) Close() {
diff --git a/pkg/mq/msgstream/mqwrapper/kafka/kafka_client_test.go b/pkg/mq/msgstream/mqwrapper/kafka/kafka_client_test.go
index 565fc67cad..27417ae56a 100644
--- a/pkg/mq/msgstream/mqwrapper/kafka/kafka_client_test.go
+++ b/pkg/mq/msgstream/mqwrapper/kafka/kafka_client_test.go
@@ -196,7 +196,7 @@ func TestKafkaClient_ConsumeWithAck(t *testing.T) {
 	Consume1(ctx1, t, kc, topic, subName, c, &total1)
 
 	lastMsgID := <-c
-	log.Info("lastMsgID", zap.Any("lastMsgID", lastMsgID.(*kafkaID).messageID))
+	log.Info("lastMsgID", zap.Any("lastMsgID", lastMsgID.(*KafkaID).MessageID))
 
 	ctx2, cancel2 := context.WithTimeout(ctx, 3*time.Second)
 	Consume2(ctx2, t, kc, topic, subName, lastMsgID, &total2)
diff --git a/pkg/mq/msgstream/mqwrapper/kafka/kafka_consumer.go b/pkg/mq/msgstream/mqwrapper/kafka/kafka_consumer.go
index bf87b260a7..8809d6b28f 100644
--- a/pkg/mq/msgstream/mqwrapper/kafka/kafka_consumer.go
+++ b/pkg/mq/msgstream/mqwrapper/kafka/kafka_consumer.go
@@ -74,7 +74,7 @@ func newKafkaConsumer(config *kafka.ConfigMap, bufSize int64, topic string, grou
 			return nil, err
 		}
 	} else {
-		offset = kafka.Offset(latestMsgID.(*kafkaID).messageID)
+		offset = kafka.Offset(latestMsgID.(*KafkaID).MessageID)
 		kc.skipMsg = true
 	}
 }
@@ -161,7 +161,7 @@ func (kc *Consumer) Seek(id common.MessageID, inclusive bool) error {
 		return errors.New("kafka consumer is already assigned, can not seek again")
 	}
 
-	offset := kafka.Offset(id.(*kafkaID).messageID)
+	offset := kafka.Offset(id.(*KafkaID).MessageID)
 	return kc.internalSeek(offset, inclusive)
 }
 
@@ -219,7 +219,7 @@ func (kc *Consumer) GetLatestMsgID() (common.MessageID, error) {
 	}
 
 	log.Info("get latest msg ID ", zap.String("topic", kc.topic), zap.Int64("oldest offset", low), zap.Int64("latest offset", high))
-	return &kafkaID{messageID: high}, nil
+	return &KafkaID{MessageID: high}, nil
 }
 
 func (kc *Consumer) CheckTopicValid(topic string) error {
diff --git a/pkg/mq/msgstream/mqwrapper/kafka/kafka_consumer_test.go b/pkg/mq/msgstream/mqwrapper/kafka/kafka_consumer_test.go
index 45bec8dad7..c058706f91 100644
--- a/pkg/mq/msgstream/mqwrapper/kafka/kafka_consumer_test.go
+++ b/pkg/mq/msgstream/mqwrapper/kafka/kafka_consumer_test.go
@@ -40,14 +40,14 @@ func TestKafkaConsumer_SeekExclusive(t *testing.T) {
 	data2 := []string{"111", "222", "333"}
 	testKafkaConsumerProduceData(t, topic, data1, data2)
 
-	msgID := &kafkaID{messageID: 1}
+	msgID := &KafkaID{MessageID: 1}
 	err = consumer.Seek(msgID, false)
 	assert.NoError(t, err)
 
 	msg := <-consumer.Chan()
 	assert.Equal(t, 333, BytesToInt(msg.Payload()))
 	assert.Equal(t, "333", msg.Properties()[common.TraceIDKey])
-	assert.Equal(t, int64(2), msg.ID().(*kafkaID).messageID)
+	assert.Equal(t, int64(2), msg.ID().(*KafkaID).MessageID)
 	assert.Equal(t, topic, msg.Topic())
 	assert.True(t, len(msg.Properties()) == 1)
 }
@@ -66,14 +66,14 @@ func TestKafkaConsumer_SeekInclusive(t *testing.T) {
 	data2 := []string{"111", "222", "333"}
 	testKafkaConsumerProduceData(t, topic, data1, data2)
 
-	msgID := &kafkaID{messageID: 1}
+	msgID := &KafkaID{MessageID: 1}
 	err = consumer.Seek(msgID, true)
 	assert.NoError(t, err)
 
 	msg := <-consumer.Chan()
 	assert.Equal(t, 222, BytesToInt(msg.Payload()))
 	assert.Equal(t, "222", msg.Properties()[common.TraceIDKey])
-	assert.Equal(t, int64(1), msg.ID().(*kafkaID).messageID)
+	assert.Equal(t, int64(1), msg.ID().(*KafkaID).MessageID)
 	assert.Equal(t, topic, msg.Topic())
 	assert.True(t, len(msg.Properties()) == 1)
 }
@@ -88,7 +88,7 @@ func TestKafkaConsumer_GetSeek(t *testing.T) {
 	assert.NoError(t, err)
 	defer consumer.Close()
 
-	msgID := &kafkaID{messageID: 0}
+	msgID := &KafkaID{MessageID: 0}
 	err = consumer.Seek(msgID, false)
 	assert.NoError(t, err)
 
@@ -163,7 +163,7 @@ func TestKafkaConsumer_GetLatestMsgID(t *testing.T) {
 	defer consumer.Close()
 
 	latestMsgID, err := consumer.GetLatestMsgID()
-	assert.Equal(t, int64(0), latestMsgID.(*kafkaID).messageID)
+	assert.Equal(t, int64(0), latestMsgID.(*KafkaID).MessageID)
 	assert.NoError(t, err)
 
 	data1 := []int{111, 222, 333}
@@ -171,7 +171,7 @@ func TestKafkaConsumer_GetLatestMsgID(t *testing.T) {
 	testKafkaConsumerProduceData(t, topic, data1, data2)
 
 	latestMsgID, err = consumer.GetLatestMsgID()
-	assert.Equal(t, int64(2), latestMsgID.(*kafkaID).messageID)
+	assert.Equal(t, int64(2), latestMsgID.(*KafkaID).MessageID)
 	assert.NoError(t, err)
 }
diff --git a/pkg/mq/msgstream/mqwrapper/kafka/kafka_id.go b/pkg/mq/msgstream/mqwrapper/kafka/kafka_id.go
index 2509065c1d..8f2d192673 100644
--- a/pkg/mq/msgstream/mqwrapper/kafka/kafka_id.go
+++ b/pkg/mq/msgstream/mqwrapper/kafka/kafka_id.go
@@ -5,26 +5,32 @@ import (
 	mqcommon "github.com/milvus-io/milvus/pkg/mq/common"
 )
 
-type kafkaID struct {
-	messageID int64
+func NewKafkaID(messageID int64) mqcommon.MessageID {
+	return &KafkaID{
+		MessageID: messageID,
+	}
 }
 
-var _ mqcommon.MessageID = &kafkaID{}
-
-func (kid *kafkaID) Serialize() []byte {
-	return SerializeKafkaID(kid.messageID)
+type KafkaID struct {
+	MessageID int64
 }
 
-func (kid *kafkaID) AtEarliestPosition() bool {
-	return kid.messageID <= 0
+var _ mqcommon.MessageID = &KafkaID{}
+
+func (kid *KafkaID) Serialize() []byte {
+	return SerializeKafkaID(kid.MessageID)
 }
 
-func (kid *kafkaID) Equal(msgID []byte) (bool, error) {
-	return kid.messageID == DeserializeKafkaID(msgID), nil
+func (kid *KafkaID) AtEarliestPosition() bool {
+	return kid.MessageID <= 0
 }
 
-func (kid *kafkaID) LessOrEqualThan(msgID []byte) (bool, error) {
-	return kid.messageID <= DeserializeKafkaID(msgID), nil
+func (kid *KafkaID) Equal(msgID []byte) (bool, error) {
+	return kid.MessageID == DeserializeKafkaID(msgID), nil
+}
+
+func (kid *KafkaID) LessOrEqualThan(msgID []byte) (bool, error) {
+	return kid.MessageID <= DeserializeKafkaID(msgID), nil
 }
 
 func SerializeKafkaID(messageID int64) []byte {
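The rename from the unexported kafkaID to the exported KafkaID exists so that packages outside mqwrapper/kafka — notably the streaming message adaptor changed later in this patch — can construct and unwrap the ID. A minimal round-trip sketch of the exported surface (the main wrapper is illustrative, not part of the patch):

package main

import (
	"fmt"

	mqkafka "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper/kafka"
)

func main() {
	id := mqkafka.NewKafkaID(42) // mqcommon.MessageID backed by *KafkaID
	bin := id.Serialize()        // 8-byte form produced by SerializeKafkaID

	// DeserializeKafkaID inverts Serialize, so the round trip is lossless.
	restored := mqkafka.NewKafkaID(mqkafka.DeserializeKafkaID(bin))
	eq, _ := restored.Equal(bin) // Equal compares against the serialized form
	fmt.Println(eq)              // true
}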
diff --git a/pkg/mq/msgstream/mqwrapper/kafka/kafka_id_test.go b/pkg/mq/msgstream/mqwrapper/kafka/kafka_id_test.go
index 29b501b66a..802fc7efa3 100644
--- a/pkg/mq/msgstream/mqwrapper/kafka/kafka_id_test.go
+++ b/pkg/mq/msgstream/mqwrapper/kafka/kafka_id_test.go
@@ -7,24 +7,24 @@ import (
 )
 
 func TestKafkaID_Serialize(t *testing.T) {
-	rid := &kafkaID{messageID: 8}
+	rid := &KafkaID{MessageID: 8}
 	bin := rid.Serialize()
 	assert.NotNil(t, bin)
 	assert.NotZero(t, len(bin))
 }
 
 func TestKafkaID_AtEarliestPosition(t *testing.T) {
-	rid := &kafkaID{messageID: 8}
+	rid := &KafkaID{MessageID: 8}
 	assert.False(t, rid.AtEarliestPosition())
 
-	rid = &kafkaID{messageID: 0}
+	rid = &KafkaID{MessageID: 0}
 	assert.True(t, rid.AtEarliestPosition())
 }
 
 func TestKafkaID_LessOrEqualThan(t *testing.T) {
 	{
-		rid1 := &kafkaID{messageID: 8}
-		rid2 := &kafkaID{messageID: 0}
+		rid1 := &KafkaID{MessageID: 8}
+		rid2 := &KafkaID{MessageID: 0}
 		ret, err := rid1.LessOrEqualThan(rid2.Serialize())
 		assert.NoError(t, err)
 		assert.False(t, ret)
@@ -35,8 +35,8 @@ func TestKafkaID_LessOrEqualThan(t *testing.T) {
 	}
 
 	{
-		rid1 := &kafkaID{messageID: 0}
-		rid2 := &kafkaID{messageID: 0}
+		rid1 := &KafkaID{MessageID: 0}
+		rid2 := &KafkaID{MessageID: 0}
 		ret, err := rid1.LessOrEqualThan(rid2.Serialize())
 		assert.NoError(t, err)
 		assert.True(t, ret)
@@ -44,8 +44,8 @@ func TestKafkaID_LessOrEqualThan(t *testing.T) {
 }
 
 func TestKafkaID_Equal(t *testing.T) {
-	rid1 := &kafkaID{messageID: 0}
-	rid2 := &kafkaID{messageID: 1}
+	rid1 := &KafkaID{MessageID: 0}
+	rid2 := &KafkaID{MessageID: 1}
 
 	{
 		ret, err := rid1.Equal(rid1.Serialize())
diff --git a/pkg/mq/msgstream/mqwrapper/kafka/kafka_message.go b/pkg/mq/msgstream/mqwrapper/kafka/kafka_message.go
index cc33c8db40..93f611c9b6 100644
--- a/pkg/mq/msgstream/mqwrapper/kafka/kafka_message.go
+++ b/pkg/mq/msgstream/mqwrapper/kafka/kafka_message.go
@@ -27,6 +27,6 @@ func (km *kafkaMessage) Payload() []byte {
 }
 
 func (km *kafkaMessage) ID() common.MessageID {
-	kid := &kafkaID{messageID: int64(km.msg.TopicPartition.Offset)}
+	kid := &KafkaID{MessageID: int64(km.msg.TopicPartition.Offset)}
 	return kid
 }
diff --git a/pkg/mq/msgstream/mqwrapper/kafka/kafka_message_test.go b/pkg/mq/msgstream/mqwrapper/kafka/kafka_message_test.go
index 3fd9363251..379f2c6acf 100644
--- a/pkg/mq/msgstream/mqwrapper/kafka/kafka_message_test.go
+++ b/pkg/mq/msgstream/mqwrapper/kafka/kafka_message_test.go
@@ -13,7 +13,7 @@ func TestKafkaMessage_All(t *testing.T) {
 	km := &kafkaMessage{msg: msg}
 	properties := make(map[string]string)
 	assert.Equal(t, topic, km.Topic())
-	assert.Equal(t, int64(0), km.ID().(*kafkaID).messageID)
+	assert.Equal(t, int64(0), km.ID().(*KafkaID).MessageID)
 	assert.Nil(t, km.Payload())
 	assert.Equal(t, properties, km.Properties())
 }
diff --git a/pkg/mq/msgstream/mqwrapper/kafka/kafka_producer.go b/pkg/mq/msgstream/mqwrapper/kafka/kafka_producer.go
index e525244d5e..edd3016049 100644
--- a/pkg/mq/msgstream/mqwrapper/kafka/kafka_producer.go
+++ b/pkg/mq/msgstream/mqwrapper/kafka/kafka_producer.go
@@ -75,7 +75,7 @@ func (kp *kafkaProducer) Send(ctx context.Context, message *mqcommon.ProducerMes
 	metrics.MsgStreamRequestLatency.WithLabelValues(metrics.SendMsgLabel).Observe(float64(elapsed.Milliseconds()))
 	metrics.MsgStreamOpCounter.WithLabelValues(metrics.SendMsgLabel, metrics.SuccessLabel).Inc()
 
-	return &kafkaID{messageID: int64(m.TopicPartition.Offset)}, nil
+	return &KafkaID{MessageID: int64(m.TopicPartition.Offset)}, nil
 }
 
 func (kp *kafkaProducer) Close() {
diff --git a/pkg/streaming/util/message/adaptor/message_id.go b/pkg/streaming/util/message/adaptor/message_id.go
index b9bc6dc333..1cd76ba1a8 100644
--- a/pkg/streaming/util/message/adaptor/message_id.go
+++ b/pkg/streaming/util/message/adaptor/message_id.go
@@ -4,11 +4,14 @@ import (
 	"fmt"
 
 	"github.com/apache/pulsar-client-go/pulsar"
+	rawKafka "github.com/confluentinc/confluent-kafka-go/kafka"
 
 	"github.com/milvus-io/milvus/pkg/mq/common"
 	"github.com/milvus-io/milvus/pkg/mq/mqimpl/rocksmq/server"
+	mqkafka "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper/kafka"
 	mqpulsar "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper/pulsar"
 	"github.com/milvus-io/milvus/pkg/streaming/util/message"
+	msgkafka "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/kafka"
 	msgpulsar "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/pulsar"
 	"github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/rmq"
 )
@@ -20,6 +23,8 @@ func MustGetMQWrapperIDFromMessage(messageID message.MessageID) common.MessageID
 		return mqpulsar.NewPulsarID(id.PulsarID())
 	} else if id, ok := messageID.(interface{ RmqID() int64 }); ok {
 		return &server.RmqID{MessageID: id.RmqID()}
+	} else if id, ok := messageID.(interface{ KafkaID() rawKafka.Offset }); ok {
+		return mqkafka.NewKafkaID(int64(id.KafkaID()))
 	}
 	panic("unsupported now")
 }
@@ -31,6 +36,8 @@ func MustGetMessageIDFromMQWrapperID(commonMessageID common.MessageID) message.M
 		return msgpulsar.NewPulsarID(id.PulsarID())
 	} else if id, ok := commonMessageID.(*server.RmqID); ok {
 		return rmq.NewRmqID(id.MessageID)
+	} else if id, ok := commonMessageID.(*mqkafka.KafkaID); ok {
+		return msgkafka.NewKafkaID(rawKafka.Offset(id.MessageID))
 	}
 	return nil
 }
@@ -48,6 +55,9 @@ func DeserializeToMQWrapperID(msgID []byte, walName string) (common.MessageID, e
 	case "rocksmq":
 		rID := server.DeserializeRmqID(msgID)
 		return &server.RmqID{MessageID: rID}, nil
+	case "kafka":
+		kID := mqkafka.DeserializeKafkaID(msgID)
+		return mqkafka.NewKafkaID(kID), nil
 	default:
 		return nil, fmt.Errorf("unsupported mq type %s", walName)
 	}
@@ -65,6 +75,9 @@ func MustGetMessageIDFromMQWrapperIDBytes(walName string, msgIDBytes []byte) mes
 			panic(err)
 		}
 		commonMsgID = mqpulsar.NewPulsarID(msgID)
+	case "kafka":
+		id := mqkafka.DeserializeKafkaID(msgIDBytes)
+		commonMsgID = mqkafka.NewKafkaID(id)
 	default:
 		panic("unsupported now")
 	}
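With these branches in place, a kafka offset survives both the in-memory and the byte-level conversion paths. A hedged sketch of how the helpers compose (variable names are illustrative):

package main

import (
	"fmt"

	mqkafka "github.com/milvus-io/milvus/pkg/mq/msgstream/mqwrapper/kafka"
	"github.com/milvus-io/milvus/pkg/streaming/util/message/adaptor"
)

func main() {
	raw := mqkafka.NewKafkaID(7).Serialize()

	// bytes + wal name -> mq wrapper ID (*mqkafka.KafkaID)
	commonID, err := adaptor.DeserializeToMQWrapperID(raw, "kafka")
	if err != nil {
		panic(err)
	}

	// bytes + wal name -> streaming message ID (the walimpls kafkaID)
	msgID := adaptor.MustGetMessageIDFromMQWrapperIDBytes("kafka", raw)

	// the in-memory conversions compose back to the same offset
	back := adaptor.MustGetMQWrapperIDFromMessage(msgID).(*mqkafka.KafkaID)
	fmt.Println(back.MessageID == commonID.(*mqkafka.KafkaID).MessageID) // true
}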
diff --git a/pkg/streaming/util/message/adaptor/message_id_test.go b/pkg/streaming/util/message/adaptor/message_id_test.go
index 6b0944e8ce..81da9f5e87 100644
--- a/pkg/streaming/util/message/adaptor/message_id_test.go
+++ b/pkg/streaming/util/message/adaptor/message_id_test.go
@@ -6,6 +6,7 @@ import (
 	"github.com/apache/pulsar-client-go/pulsar"
 	"github.com/stretchr/testify/assert"
 
+	msgkafka "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/kafka"
 	msgpulsar "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/pulsar"
 	"github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/rmq"
 )
@@ -17,4 +18,7 @@ func TestIDConvension(t *testing.T) {
 	msgID := pulsar.EarliestMessageID()
 	id = MustGetMessageIDFromMQWrapperID(MustGetMQWrapperIDFromMessage(msgpulsar.NewPulsarID(msgID)))
 	assert.True(t, id.EQ(msgpulsar.NewPulsarID(msgID)))
+
+	kafkaID := MustGetMessageIDFromMQWrapperID(MustGetMQWrapperIDFromMessage(msgkafka.NewKafkaID(1)))
+	assert.True(t, kafkaID.EQ(msgkafka.NewKafkaID(1)))
 }
diff --git a/pkg/streaming/walimpls/impls/kafka/builder.go b/pkg/streaming/walimpls/impls/kafka/builder.go
new file mode 100644
index 0000000000..5c2cba15cd
--- /dev/null
+++ b/pkg/streaming/walimpls/impls/kafka/builder.go
@@ -0,0 +1,109 @@
+package kafka
+
+import (
+	"github.com/confluentinc/confluent-kafka-go/kafka"
+
+	"github.com/milvus-io/milvus/pkg/streaming/util/message"
+	"github.com/milvus-io/milvus/pkg/streaming/walimpls"
+	"github.com/milvus-io/milvus/pkg/streaming/walimpls/registry"
+	"github.com/milvus-io/milvus/pkg/util/paramtable"
+)
+
+const (
+	walName = "kafka"
+)
+
+func init() {
+	// register the builder to the wal registry.
+	registry.RegisterBuilder(&builderImpl{})
+	// register the unmarshaler to the message registry.
+	message.RegisterMessageIDUnmsarshaler(walName, UnmarshalMessageID)
+}
+
+// builderImpl is the builder for kafka wal.
+type builderImpl struct{}
+
+// Name returns the name of the wal.
+func (b *builderImpl) Name() string {
+	return walName
+}
+
+// Build builds a wal instance.
+func (b *builderImpl) Build() (walimpls.OpenerImpls, error) {
+	producerConfig, consumerConfig := b.getProducerConfig(), b.getConsumerConfig()
+
+	p, err := kafka.NewProducer(&producerConfig)
+	if err != nil {
+		return nil, err
+	}
+	return newOpenerImpl(p, consumerConfig), nil
+}
+
+// getProducerConfig returns the producer config.
+func (b *builderImpl) getProducerConfig() kafka.ConfigMap {
+	config := &paramtable.Get().KafkaCfg
+	producerConfig := getBasicConfig(config)
+
+	producerConfig.SetKey("message.max.bytes", 10485760)
+	producerConfig.SetKey("compression.codec", "zstd")
+	// we want to ensure time tick messages are sent out as soon as possible
+	producerConfig.SetKey("linger.ms", 5)
+	for k, v := range config.ProducerExtraConfig.GetValue() {
+		producerConfig.SetKey(k, v)
+	}
+	return producerConfig
+}
+
+func (b *builderImpl) getConsumerConfig() kafka.ConfigMap {
+	config := &paramtable.Get().KafkaCfg
+	consumerConfig := getBasicConfig(config)
+	consumerConfig.SetKey("allow.auto.create.topics", true)
+	for k, v := range config.ConsumerExtraConfig.GetValue() {
+		consumerConfig.SetKey(k, v)
+	}
+	return consumerConfig
+}
+
+// getBasicConfig returns the basic kafka config.
+func getBasicConfig(config *paramtable.KafkaConfig) kafka.ConfigMap {
+	basicConfig := kafka.ConfigMap{
+		"bootstrap.servers":        config.Address.GetValue(),
+		"api.version.request":      true,
+		"reconnect.backoff.ms":     20,
+		"reconnect.backoff.max.ms": 5000,
+	}
+
+	if (config.SaslUsername.GetValue() == "" && config.SaslPassword.GetValue() != "") ||
+		(config.SaslUsername.GetValue() != "" && config.SaslPassword.GetValue() == "") {
+		panic("enable security mode need config username and password at the same time!")
+	}
+
+	if config.SecurityProtocol.GetValue() != "" {
+		basicConfig.SetKey("security.protocol", config.SecurityProtocol.GetValue())
+	}
+
+	if config.SaslUsername.GetValue() != "" && config.SaslPassword.GetValue() != "" {
+		basicConfig.SetKey("sasl.mechanisms", config.SaslMechanisms.GetValue())
+		basicConfig.SetKey("sasl.username", config.SaslUsername.GetValue())
+		basicConfig.SetKey("sasl.password", config.SaslPassword.GetValue())
+	}
+
+	if config.KafkaUseSSL.GetAsBool() {
+		basicConfig.SetKey("ssl.certificate.location", config.KafkaTLSCert.GetValue())
+		basicConfig.SetKey("ssl.key.location", config.KafkaTLSKey.GetValue())
+		basicConfig.SetKey("ssl.ca.location", config.KafkaTLSCACert.GetValue())
+		if config.KafkaTLSKeyPassword.GetValue() != "" {
+			basicConfig.SetKey("ssl.key.password", config.KafkaTLSKeyPassword.GetValue())
+		}
+	}
+	return basicConfig
+}
+
+// cloneKafkaConfig clones a kafka config.
+func cloneKafkaConfig(config kafka.ConfigMap) kafka.ConfigMap {
+	newConfig := make(kafka.ConfigMap)
+	for k, v := range config {
+		newConfig[k] = v
+	}
+	return newConfig
+}
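cloneKafkaConfig matters because kafka.ConfigMap is a plain map: the opener hands the same consumer config to every Read call, and each scanner then sets its own group.id (see walImpl.Read further down). Cloning before mutation keeps scanners from racing on, or leaking settings into, the shared map. A minimal in-package sketch of the pattern (the broker address is an illustrative value):

base := kafka.ConfigMap{"bootstrap.servers": "localhost:9092"}

perScanner := cloneKafkaConfig(base)
perScanner.SetKey("group.id", "scanner-1") // only the copy gains a group.id

_, leaked := base["group.id"] // false: the shared config is untouched
_ = leaked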
diff --git a/pkg/streaming/walimpls/impls/kafka/kafka_test.go b/pkg/streaming/walimpls/impls/kafka/kafka_test.go
new file mode 100644
index 0000000000..fd8434d7ef
--- /dev/null
+++ b/pkg/streaming/walimpls/impls/kafka/kafka_test.go
@@ -0,0 +1,54 @@
+package kafka
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+
+	"github.com/milvus-io/milvus/pkg/streaming/util/message"
+	"github.com/milvus-io/milvus/pkg/streaming/walimpls"
+	"github.com/milvus-io/milvus/pkg/streaming/walimpls/registry"
+	"github.com/milvus-io/milvus/pkg/util/paramtable"
+)
+
+func TestMain(m *testing.M) {
+	paramtable.Init()
+	m.Run()
+}
+
+func TestRegistry(t *testing.T) {
+	registeredB := registry.MustGetBuilder(walName)
+	assert.NotNil(t, registeredB)
+	assert.Equal(t, walName, registeredB.Name())
+
+	id, err := message.UnmarshalMessageID(walName, kafkaID(123).Marshal())
+	assert.NoError(t, err)
+	assert.True(t, id.EQ(kafkaID(123)))
+}
+
+func TestKafka(t *testing.T) {
+	walimpls.NewWALImplsTestFramework(t, 100, &builderImpl{}).Run()
+}
+
+func TestGetBasicConfig(t *testing.T) {
+	config := &paramtable.Get().KafkaCfg
+	oldSecurityProtocol := config.SecurityProtocol.SwapTempValue("test")
+	oldSaslUsername := config.SaslUsername.SwapTempValue("test")
+	oldSaslPassword := config.SaslPassword.SwapTempValue("test")
+	oldkafkaUseSSL := config.KafkaUseSSL.SwapTempValue("true")
+	oldKafkaTLSKeyPassword := config.KafkaTLSKeyPassword.SwapTempValue("test")
+	defer func() {
+		config.SecurityProtocol.SwapTempValue(oldSecurityProtocol)
+		config.SaslUsername.SwapTempValue(oldSaslUsername)
+		config.SaslPassword.SwapTempValue(oldSaslPassword)
+		config.KafkaUseSSL.SwapTempValue(oldkafkaUseSSL)
+		config.KafkaTLSKeyPassword.SwapTempValue(oldKafkaTLSKeyPassword)
+	}()
+	basicConfig := getBasicConfig(config)
+
+	assert.NotNil(t, basicConfig["ssl.key.password"])
+	assert.NotNil(t, basicConfig["ssl.certificate.location"])
+	assert.NotNil(t, basicConfig["sasl.username"])
+	assert.NotNil(t, basicConfig["security.protocol"])
+}
basicConfig["sasl.username"]) + assert.NotNil(t, basicConfig["security.protocol"]) +} diff --git a/pkg/streaming/walimpls/impls/kafka/message_id.go b/pkg/streaming/walimpls/impls/kafka/message_id.go new file mode 100644 index 0000000000..f99ea4ba13 --- /dev/null +++ b/pkg/streaming/walimpls/impls/kafka/message_id.go @@ -0,0 +1,68 @@ +package kafka + +import ( + "strconv" + + "github.com/cockroachdb/errors" + "github.com/confluentinc/confluent-kafka-go/kafka" + + "github.com/milvus-io/milvus/pkg/streaming/util/message" +) + +func UnmarshalMessageID(data string) (message.MessageID, error) { + id, err := unmarshalMessageID(data) + if err != nil { + return nil, err + } + return id, nil +} + +func unmarshalMessageID(data string) (kafkaID, error) { + v, err := message.DecodeUint64(data) + if err != nil { + return 0, errors.Wrapf(message.ErrInvalidMessageID, "decode kafkaID fail with err: %s, id: %s", err.Error(), data) + } + return kafkaID(v), nil +} + +func NewKafkaID(offset kafka.Offset) message.MessageID { + return kafkaID(offset) +} + +type kafkaID kafka.Offset + +// RmqID returns the message id for conversion +// Don't delete this function until conversion logic removed. +// TODO: remove in future. +func (id kafkaID) KafkaID() kafka.Offset { + return kafka.Offset(id) +} + +// WALName returns the name of message id related wal. +func (id kafkaID) WALName() string { + return walName +} + +// LT less than. +func (id kafkaID) LT(other message.MessageID) bool { + return id < other.(kafkaID) +} + +// LTE less than or equal to. +func (id kafkaID) LTE(other message.MessageID) bool { + return id <= other.(kafkaID) +} + +// EQ Equal to. +func (id kafkaID) EQ(other message.MessageID) bool { + return id == other.(kafkaID) +} + +// Marshal marshal the message id. 
diff --git a/pkg/streaming/walimpls/impls/kafka/message_id_test.go b/pkg/streaming/walimpls/impls/kafka/message_id_test.go
new file mode 100644
index 0000000000..507a05e1b7
--- /dev/null
+++ b/pkg/streaming/walimpls/impls/kafka/message_id_test.go
@@ -0,0 +1,32 @@
+package kafka
+
+import (
+	"testing"
+
+	"github.com/confluentinc/confluent-kafka-go/kafka"
+	"github.com/stretchr/testify/assert"
+
+	"github.com/milvus-io/milvus/pkg/streaming/util/message"
+)
+
+func TestMessageID(t *testing.T) {
+	assert.Equal(t, kafka.Offset(1), message.MessageID(kafkaID(1)).(interface{ KafkaID() kafka.Offset }).KafkaID())
+	assert.Equal(t, walName, kafkaID(1).WALName())
+
+	assert.True(t, kafkaID(1).LT(kafkaID(2)))
+	assert.True(t, kafkaID(1).EQ(kafkaID(1)))
+	assert.True(t, kafkaID(1).LTE(kafkaID(1)))
+	assert.True(t, kafkaID(1).LTE(kafkaID(2)))
+	assert.False(t, kafkaID(2).LT(kafkaID(1)))
+	assert.False(t, kafkaID(2).EQ(kafkaID(1)))
+	assert.False(t, kafkaID(2).LTE(kafkaID(1)))
+	assert.True(t, kafkaID(2).LTE(kafkaID(2)))
+
+	msgID, err := UnmarshalMessageID(kafkaID(1).Marshal())
+	assert.NoError(t, err)
+	assert.Equal(t, kafkaID(1), msgID)
+
+	_, err = UnmarshalMessageID(string([]byte{0x01, 0x02, 0x03, 0x04}))
+	assert.Error(t, err)
+}
diff --git a/pkg/streaming/walimpls/impls/kafka/opener.go b/pkg/streaming/walimpls/impls/kafka/opener.go
new file mode 100644
index 0000000000..4f2464c370
--- /dev/null
+++ b/pkg/streaming/walimpls/impls/kafka/opener.go
@@ -0,0 +1,73 @@
+package kafka
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/confluentinc/confluent-kafka-go/kafka"
+	"go.uber.org/zap"
+
+	"github.com/milvus-io/milvus/pkg/log"
+	"github.com/milvus-io/milvus/pkg/streaming/walimpls"
+	"github.com/milvus-io/milvus/pkg/streaming/walimpls/helper"
+	"github.com/milvus-io/milvus/pkg/util/syncutil"
+)
+
+var _ walimpls.OpenerImpls = (*openerImpl)(nil)
+
+// newOpenerImpl creates a new openerImpl instance.
+func newOpenerImpl(p *kafka.Producer, consumerConfig kafka.ConfigMap) *openerImpl {
+	o := &openerImpl{
+		n:              syncutil.NewAsyncTaskNotifier[struct{}](),
+		p:              p,
+		consumerConfig: consumerConfig,
+	}
+	go o.execute()
+	return o
+}
+
+// openerImpl is the opener implementation for kafka wal.
+type openerImpl struct {
+	n              *syncutil.AsyncTaskNotifier[struct{}]
+	p              *kafka.Producer
+	consumerConfig kafka.ConfigMap
+}
+
+func (o *openerImpl) Open(ctx context.Context, opt *walimpls.OpenOption) (walimpls.WALImpls, error) {
+	return &walImpl{
+		WALHelper:      helper.NewWALHelper(opt),
+		p:              o.p,
+		consumerConfig: o.consumerConfig,
+	}, nil
+}
+
+func (o *openerImpl) execute() {
+	defer o.n.Finish(struct{}{})
+
+	for {
+		select {
+		case <-o.n.Context().Done():
+			return
+		case ev, ok := <-o.p.Events():
+			if !ok {
+				panic("kafka producer events channel should never be closed before the execute observer exit")
+			}
+			switch ev := ev.(type) {
+			case kafka.Error:
+				log.Error("kafka producer error", zap.Error(ev))
+				if ev.IsFatal() {
+					panic(fmt.Sprintf("kafka producer error is fatal, %s", ev.Error()))
+				}
+			default:
+				// ignore other events
+				log.Debug("kafka producer incoming non-message, non-error event", zap.String("event", ev.String()))
+			}
+		}
+	}
+}
+
+func (o *openerImpl) Close() {
+	o.n.Cancel()
+	o.n.BlockUntilFinish()
+	o.p.Close()
+}
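Note the two distinct event paths here: walImpl.Append (below) passes a private delivery channel to Produce, so per-message delivery reports never reach the shared Events() channel that openerImpl.execute drains — that loop only ever sees client-level events such as transport errors. A hedged sketch of the split, assuming a connected *kafka.Producer p:

// per-message path: the delivery report arrives on the channel given to Produce.
deliveryChan := make(chan kafka.Event, 1)
_ = p.Produce(&kafka.Message{ /* topic, value, headers */ }, deliveryChan)
report := (<-deliveryChan).(*kafka.Message) // mirrors walImpl.Append
_ = report

// client-level path: everything else surfaces on p.Events().
go func() {
	for ev := range p.Events() {
		if err, ok := ev.(kafka.Error); ok && err.IsFatal() {
			panic(err) // mirrors openerImpl.execute's fatal-error handling
		}
	}
}()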
diff --git a/pkg/streaming/walimpls/impls/kafka/scanner.go b/pkg/streaming/walimpls/impls/kafka/scanner.go
new file mode 100644
index 0000000000..934b00da84
--- /dev/null
+++ b/pkg/streaming/walimpls/impls/kafka/scanner.go
@@ -0,0 +1,88 @@
+package kafka
+
+import (
+	"time"
+
+	"github.com/confluentinc/confluent-kafka-go/kafka"
+
+	"github.com/milvus-io/milvus/pkg/streaming/util/message"
+	"github.com/milvus-io/milvus/pkg/streaming/walimpls"
+	"github.com/milvus-io/milvus/pkg/streaming/walimpls/helper"
+)
+
+var _ walimpls.ScannerImpls = (*scannerImpl)(nil)
+
+// newScanner creates a new scanner.
+func newScanner(scannerName string, exclude *kafkaID, consumer *kafka.Consumer) *scannerImpl {
+	s := &scannerImpl{
+		ScannerHelper: helper.NewScannerHelper(scannerName),
+		consumer:      consumer,
+		msgChannel:    make(chan message.ImmutableMessage, 1),
+		exclude:       exclude,
+	}
+	go s.executeConsume()
+	return s
+}
+
+// scannerImpl is the implementation of ScannerImpls for kafka.
+type scannerImpl struct {
+	*helper.ScannerHelper
+	consumer   *kafka.Consumer
+	msgChannel chan message.ImmutableMessage
+	exclude    *kafkaID
+}
+
+// Chan returns the channel of message.
+func (s *scannerImpl) Chan() <-chan message.ImmutableMessage {
+	return s.msgChannel
+}
+
+// Close the scanner and release the underlying resources.
+// Returns the same error as `Error`.
+func (s *scannerImpl) Close() error {
+	s.consumer.Unassign()
+	err := s.ScannerHelper.Close()
+	s.consumer.Close()
+	return err
+}
+
+func (s *scannerImpl) executeConsume() {
+	defer close(s.msgChannel)
+	for {
+		msg, err := s.consumer.ReadMessage(200 * time.Millisecond)
+		if err != nil {
+			if s.Context().Err() != nil {
+				// context canceled, means the scanner is closed.
+				s.Finish(nil)
+				return
+			}
+			if c, ok := err.(kafka.Error); ok && c.Code() == kafka.ErrTimedOut {
+				continue
+			}
+			s.Finish(err)
+			return
+		}
+		messageID := kafkaID(msg.TopicPartition.Offset)
+		if s.exclude != nil && messageID.EQ(*s.exclude) {
+			// Skip the message that is excluded, for StartAfter semantics.
+			continue
+		}
+
+		properties := make(map[string]string, len(msg.Headers))
+		for _, header := range msg.Headers {
+			properties[header.Key] = string(header.Value)
+		}
+
+		newImmutableMessage := message.NewImmutableMesasge(
+			messageID,
+			msg.Value,
+			properties,
+		)
+		select {
+		case <-s.Context().Done():
+			s.Finish(nil)
+			return
+		case s.msgChannel <- newImmutableMessage:
+		}
+	}
+}
diff --git a/pkg/streaming/walimpls/impls/kafka/wal.go b/pkg/streaming/walimpls/impls/kafka/wal.go
new file mode 100644
index 0000000000..63d0bcb492
--- /dev/null
+++ b/pkg/streaming/walimpls/impls/kafka/wal.go
@@ -0,0 +1,105 @@
+package kafka
+
+import (
+	"context"
+
+	"github.com/cockroachdb/errors"
+	"github.com/confluentinc/confluent-kafka-go/kafka"
+
+	"github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb"
+	"github.com/milvus-io/milvus/pkg/streaming/util/message"
+	"github.com/milvus-io/milvus/pkg/streaming/walimpls"
+	"github.com/milvus-io/milvus/pkg/streaming/walimpls/helper"
+)
+
+var _ walimpls.WALImpls = (*walImpl)(nil)
+
+type walImpl struct {
+	*helper.WALHelper
+	p              *kafka.Producer
+	consumerConfig kafka.ConfigMap
+}
+
+func (w *walImpl) WALName() string {
+	return walName
+}
+
+func (w *walImpl) Append(ctx context.Context, msg message.MutableMessage) (message.MessageID, error) {
+	properties := msg.Properties().ToRawMap()
+	headers := make([]kafka.Header, 0, len(properties))
+	for key, value := range properties {
+		header := kafka.Header{Key: key, Value: []byte(value)}
+		headers = append(headers, header)
+	}
+	ch := make(chan kafka.Event, 1)
+	topic := w.Channel().Name
+
+	if err := w.p.Produce(&kafka.Message{
+		TopicPartition: kafka.TopicPartition{Topic: &topic, Partition: 0},
+		Value:          msg.Payload(),
+		Headers:        headers,
+	}, ch); err != nil {
+		return nil, err
+	}
+
+	select {
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	case event := <-ch:
+		relatedMsg := event.(*kafka.Message)
+		if relatedMsg.TopicPartition.Error != nil {
+			return nil, relatedMsg.TopicPartition.Error
+		}
+		return kafkaID(relatedMsg.TopicPartition.Offset), nil
+	}
+}
+
+func (w *walImpl) Read(ctx context.Context, opt walimpls.ReadOption) (s walimpls.ScannerImpls, err error) {
+	// The scanner is stateless, so we can create it with an anonymous consumer,
+	// and there are no commit operations.
+	consumerConfig := cloneKafkaConfig(w.consumerConfig)
+	consumerConfig.SetKey("group.id", opt.Name)
+	c, err := kafka.NewConsumer(&consumerConfig)
+	if err != nil {
+		return nil, errors.Wrap(err, "failed to create kafka consumer")
+	}
+
+	topic := w.Channel().Name
+	seekPosition := kafka.TopicPartition{
+		Topic:     &topic,
+		Partition: 0,
+	}
+	var exclude *kafkaID
+	switch t := opt.DeliverPolicy.GetPolicy().(type) {
+	case *streamingpb.DeliverPolicy_All:
+		seekPosition.Offset = kafka.OffsetBeginning
+	case *streamingpb.DeliverPolicy_Latest:
+		seekPosition.Offset = kafka.OffsetEnd
+	case *streamingpb.DeliverPolicy_StartFrom:
+		id, err := unmarshalMessageID(t.StartFrom.GetId())
+		if err != nil {
+			return nil, err
+		}
+		seekPosition.Offset = kafka.Offset(id)
+	case *streamingpb.DeliverPolicy_StartAfter:
+		id, err := unmarshalMessageID(t.StartAfter.GetId())
+		if err != nil {
+			return nil, err
+		}
+		seekPosition.Offset = kafka.Offset(id)
+		exclude = &id
+	default:
+		panic("unknown deliver policy")
+	}
+
+	if err := c.Assign([]kafka.TopicPartition{seekPosition}); err != nil {
+		return nil, errors.Wrap(err, "failed to assign kafka consumer")
+	}
+	return newScanner(opt.Name, exclude, c), nil
+}
+
+func (w *walImpl) Close() {
+	// The lifetime control of the producer is delegated to the wal adaptor,
+	// so resource cleanup would normally happen here.
+	// But the kafka producer is not topic-level, so we don't close it here.
+}
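The policy handling in Read compresses into a small mapping; kafka can only seek *to* an offset, so StartAfter is implemented as StartFrom plus the scanner-side skip of the excluded ID. An illustrative in-package distillation (not part of the patch):

// seekFor is a hypothetical helper mirroring walImpl.Read's policy handling:
// it returns the offset to assign and the ID the scanner must skip, if any.
func seekFor(policy string, id kafka.Offset) (offset kafka.Offset, exclude *kafka.Offset) {
	switch policy {
	case "all":
		return kafka.OffsetBeginning, nil
	case "latest":
		return kafka.OffsetEnd, nil
	case "start_from":
		return id, nil // id itself is delivered
	case "start_after":
		return id, &id // seek to id; the scanner drops the message equal to id
	default:
		panic("unknown deliver policy")
	}
}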
diff --git a/pkg/streaming/walimpls/impls/pulsar/message_id_test.go b/pkg/streaming/walimpls/impls/pulsar/message_id_test.go
index 5635c939c9..e6d31f7716 100644
--- a/pkg/streaming/walimpls/impls/pulsar/message_id_test.go
+++ b/pkg/streaming/walimpls/impls/pulsar/message_id_test.go
@@ -6,9 +6,18 @@ import (
 	"github.com/apache/pulsar-client-go/pulsar"
 	"github.com/stretchr/testify/assert"
 	"google.golang.org/protobuf/proto"
+
+	"github.com/milvus-io/milvus/pkg/streaming/util/message"
 )
 
 func TestMessageID(t *testing.T) {
+	pid := message.MessageID(newMessageIDOfPulsar(1, 2, 3)).(interface{ PulsarID() pulsar.MessageID }).PulsarID()
+	assert.Equal(t, walName, newMessageIDOfPulsar(1, 2, 3).WALName())
+
+	assert.Equal(t, int64(1), pid.LedgerID())
+	assert.Equal(t, int64(2), pid.EntryID())
+	assert.Equal(t, int32(3), pid.BatchIdx())
+
 	ids := []pulsarID{
 		newMessageIDOfPulsar(0, 0, 0),
 		newMessageIDOfPulsar(0, 0, 1),
diff --git a/pkg/streaming/walimpls/impls/rmq/message_id_test.go b/pkg/streaming/walimpls/impls/rmq/message_id_test.go
index b757e57ab6..e37bfdf056 100644
--- a/pkg/streaming/walimpls/impls/rmq/message_id_test.go
+++ b/pkg/streaming/walimpls/impls/rmq/message_id_test.go
@@ -4,9 +4,14 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+
+	"github.com/milvus-io/milvus/pkg/streaming/util/message"
 )
 
 func TestMessageID(t *testing.T) {
+	assert.Equal(t, int64(1), message.MessageID(rmqID(1)).(interface{ RmqID() int64 }).RmqID())
+	assert.Equal(t, walName, rmqID(1).WALName())
+
 	assert.True(t, rmqID(1).LT(rmqID(2)))
 	assert.True(t, rmqID(1).EQ(rmqID(1)))
 	assert.True(t, rmqID(1).LTE(rmqID(1)))
diff --git a/pkg/util/paramtable/param_item.go b/pkg/util/paramtable/param_item.go
index 93be87ce7c..ca70589221 100644
--- a/pkg/util/paramtable/param_item.go
+++ b/pkg/util/paramtable/param_item.go
@@ -101,12 +101,18 @@ func (pi *ParamItem) getWithRaw() (result, raw string, err error) {
 // SetTempValue set the value for this ParamItem,
 // Once value set, ParamItem will use the value instead of underlying config manager.
 // Usage: should only use for unittest, swap empty string will remove the value.
-func (pi *ParamItem) SwapTempValue(s string) *string {
+func (pi *ParamItem) SwapTempValue(s string) string {
 	if s == "" {
-		return pi.tempValue.Swap(nil)
+		if old := pi.tempValue.Swap(nil); old != nil {
+			return *old
+		}
+		return ""
 	}
 	pi.manager.EvictCachedValue(pi.Key)
-	return pi.tempValue.Swap(&s)
+	if old := pi.tempValue.Swap(&s); old != nil {
+		return *old
+	}
+	return ""
 }
 
 func (pi *ParamItem) GetValue() string {
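The return-type change from *string to string makes the common save/restore pattern in tests nil-safe: callers no longer guard against a nil pointer when no temp value was set, because the empty string both represents and restores the unset state (TestGetBasicConfig above relies on exactly this). A sketch of the intended usage, where pi is an illustrative ParamItem:

old := pi.SwapTempValue("test") // previous temp value, or "" if none was set
defer pi.SwapTempValue(old)     // swapping "" back removes the temp value again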