From bd17544fa50ccba44d695ff66a6aead0ecb2a55f Mon Sep 17 00:00:00 2001 From: SimFG Date: Wed, 28 Jun 2023 18:36:26 +0800 Subject: [PATCH] Fix the kafka panic when sending the message to a closed channel (#25116) Signed-off-by: SimFG --- .../mqwrapper/kafka/kafka_producer.go | 14 ++++++++- .../mqwrapper/kafka/kafka_producer_test.go | 30 +++++++++++++++++++ internal/querynode/validate_test.go | 9 ++++++ 3 files changed, 52 insertions(+), 1 deletion(-) diff --git a/internal/mq/msgstream/mqwrapper/kafka/kafka_producer.go b/internal/mq/msgstream/mqwrapper/kafka/kafka_producer.go index 724a8d6671..ce019131d7 100644 --- a/internal/mq/msgstream/mqwrapper/kafka/kafka_producer.go +++ b/internal/mq/msgstream/mqwrapper/kafka/kafka_producer.go @@ -37,6 +37,7 @@ type kafkaProducer struct { topic string deliveryChan chan kafka.Event closeOnce sync.Once + isClosed bool } func (kp *kafkaProducer) Topic() string { @@ -47,6 +48,12 @@ func (kp *kafkaProducer) Send(ctx context.Context, message *mqwrapper.ProducerMe start := timerecord.NewTimeRecorder("send msg to stream") metrics.MsgStreamOpCounter.WithLabelValues(metrics.SendMsgLabel, metrics.TotalLabel).Inc() + if kp.isClosed { + metrics.MsgStreamOpCounter.WithLabelValues(metrics.SendMsgLabel, metrics.FailLabel).Inc() + log.Error("kafka produce message fail because the producer has been closed", zap.String("topic", kp.topic)) + return nil, common.NewIgnorableError(fmt.Errorf("kafka producer is closed")) + } + err := kp.p.Produce(&kafka.Message{ TopicPartition: kafka.TopicPartition{Topic: &kp.topic, Partition: mqwrapper.DefaultPartitionIdx}, Value: message.Payload, @@ -79,9 +86,14 @@ func (kp *kafkaProducer) Send(ctx context.Context, message *mqwrapper.ProducerMe func (kp *kafkaProducer) Close() { kp.closeOnce.Do(func() { + kp.isClosed = true + start := time.Now() //flush in-flight msg within queue. - kp.p.Flush(10000) + i := kp.p.Flush(10000) + if i > 0 { + log.Warn("There are still un-flushed outstanding events", zap.Int("event_num", i), zap.Any("topic", kp.topic)) + } close(kp.deliveryChan) diff --git a/internal/mq/msgstream/mqwrapper/kafka/kafka_producer_test.go b/internal/mq/msgstream/mqwrapper/kafka/kafka_producer_test.go index 3c1df35cba..602463d991 100644 --- a/internal/mq/msgstream/mqwrapper/kafka/kafka_producer_test.go +++ b/internal/mq/msgstream/mqwrapper/kafka/kafka_producer_test.go @@ -27,6 +27,8 @@ func TestKafkaProducer_SendSuccess(t *testing.T) { assert.Nil(t, err) assert.NotNil(t, producer) + producer.Close() + kafkaProd := producer.(*kafkaProducer) assert.Equal(t, kafkaProd.Topic(), topic) @@ -35,6 +37,7 @@ func TestKafkaProducer_SendSuccess(t *testing.T) { Properties: map[string]string{}, } msgID, err := producer.Send(context.TODO(), msg2) + time.Sleep(30 * time.Second) assert.Nil(t, err) assert.NotNil(t, msgID) @@ -67,3 +70,30 @@ func TestKafkaProducer_SendFail(t *testing.T) { producer.Close() } } + +func TestKafkaProducer_SendFailAfterClose(t *testing.T) { + kafkaAddress := getKafkaBrokerList() + kc := NewKafkaClientInstance(kafkaAddress) + defer kc.Close() + assert.NotNil(t, kc) + + rand.Seed(time.Now().UnixNano()) + topic := fmt.Sprintf("test-topic-%d", rand.Int()) + + producer, err := kc.CreateProducer(mqwrapper.ProducerOptions{Topic: topic}) + assert.Nil(t, err) + assert.NotNil(t, producer) + + producer.Close() + + kafkaProd := producer.(*kafkaProducer) + assert.Equal(t, kafkaProd.Topic(), topic) + + msg2 := &mqwrapper.ProducerMessage{ + Payload: []byte{}, + Properties: map[string]string{}, + } + _, err = producer.Send(context.TODO(), msg2) + time.Sleep(10 * time.Second) + assert.NotNil(t, err) +} diff --git a/internal/querynode/validate_test.go b/internal/querynode/validate_test.go index ea505a6bc4..288b832675 100644 --- a/internal/querynode/validate_test.go +++ b/internal/querynode/validate_test.go @@ -30,6 +30,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) { t.Run("test normal validate", func(t *testing.T) { his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) + defer his.freeAll() _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID}) assert.NoError(t, err) }) @@ -37,6 +38,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) { t.Run("test normal validate2", func(t *testing.T) { his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) + defer his.freeAll() _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID}) assert.NoError(t, err) }) @@ -44,6 +46,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) { t.Run("test validate non-existent collection", func(t *testing.T) { his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) + defer his.freeAll() _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID+1, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID}) assert.Error(t, err) }) @@ -51,6 +54,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) { t.Run("test validate non-existent partition", func(t *testing.T) { his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) + defer his.freeAll() _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{defaultPartitionID + 1}, []UniqueID{defaultSegmentID}) assert.Error(t, err) }) @@ -58,6 +62,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) { t.Run("test validate non-existent segment", func(t *testing.T) { his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) + defer his.freeAll() _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{defaultPartitionID}, []UniqueID{defaultSegmentID + 1}) assert.NoError(t, err) }) @@ -65,6 +70,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) { t.Run("test validate segment not in given partition", func(t *testing.T) { his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) + defer his.freeAll() err = his.addPartition(defaultCollectionID, defaultPartitionID+1) assert.NoError(t, err) schema := genTestCollectionSchema() @@ -86,6 +92,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) { t.Run("test validate after partition release", func(t *testing.T) { his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) + defer his.freeAll() err = his.removePartition(defaultPartitionID) assert.NoError(t, err) _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID}) @@ -95,6 +102,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) { t.Run("test validate after partition release2", func(t *testing.T) { his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) + defer his.freeAll() col, err := his.getCollectionByID(defaultCollectionID) assert.NoError(t, err) col.setLoadType(loadTypePartition) @@ -107,6 +115,7 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) { t.Run("test validate after partition release3", func(t *testing.T) { his, err := genSimpleReplicaWithSealSegment(ctx) assert.NoError(t, err) + defer his.freeAll() col, err := his.getCollectionByID(defaultCollectionID) assert.NoError(t, err) col.setLoadType(loadTypeCollection)