From 08fa51d4f453dc8bb1c815704fba505ff32ebb1c Mon Sep 17 00:00:00 2001 From: jaime Date: Sun, 28 Jul 2024 21:50:20 +0800 Subject: [PATCH] fix: memory leak while parsing query plan (#34931) issue: #34930 Signed-off-by: jaime --- .../parser/planparserv2/error_listener.go | 13 +++++-- .../parser/planparserv2/plan_parser_v2.go | 18 +++++----- .../planparserv2/plan_parser_v2_test.go | 34 +++++++++++++++++++ internal/parser/planparserv2/pool.go | 2 ++ internal/parser/planparserv2/pool_test.go | 10 +++--- 5 files changed, 62 insertions(+), 15 deletions(-) diff --git a/internal/parser/planparserv2/error_listener.go b/internal/parser/planparserv2/error_listener.go index da5e3e449f..ae429ad1e5 100644 --- a/internal/parser/planparserv2/error_listener.go +++ b/internal/parser/planparserv2/error_listener.go @@ -7,11 +7,20 @@ import ( "github.com/antlr/antlr4/runtime/Go/antlr" ) -type errorListener struct { +type errorListener interface { + antlr.ErrorListener + Error() error +} + +type errorListenerImpl struct { *antlr.DefaultErrorListener err error } -func (l *errorListener) SyntaxError(recognizer antlr.Recognizer, offendingSymbol interface{}, line, column int, msg string, e antlr.RecognitionException) { +func (l *errorListenerImpl) SyntaxError(recognizer antlr.Recognizer, offendingSymbol interface{}, line, column int, msg string, e antlr.RecognitionException) { l.err = fmt.Errorf("line " + strconv.Itoa(line) + ":" + strconv.Itoa(column) + " " + msg) } + +func (l *errorListenerImpl) Error() error { + return l.err +} diff --git a/internal/parser/planparserv2/plan_parser_v2.go b/internal/parser/planparserv2/plan_parser_v2.go index e8e5e94a59..ac57bdbbc7 100644 --- a/internal/parser/planparserv2/plan_parser_v2.go +++ b/internal/parser/planparserv2/plan_parser_v2.go @@ -14,6 +14,10 @@ import ( ) func handleExpr(schema *typeutil.SchemaHelper, exprStr string) interface{} { + return handleExprWithErrorListener(schema, exprStr, &errorListenerImpl{}) +} + +func handleExprWithErrorListener(schema *typeutil.SchemaHelper, exprStr string, errorListener errorListener) interface{} { if isEmptyExpression(exprStr) { return &ExprWithType{ dataType: schemapb.DataType_Bool, @@ -22,21 +26,19 @@ func handleExpr(schema *typeutil.SchemaHelper, exprStr string) interface{} { } inputStream := antlr.NewInputStream(exprStr) - errorListener := &errorListener{} - lexer := getLexer(inputStream, errorListener) - if errorListener.err != nil { - return errorListener.err + if errorListener.Error() != nil { + return errorListener.Error() } parser := getParser(lexer, errorListener) - if errorListener.err != nil { - return errorListener.err + if errorListener.Error() != nil { + return errorListener.Error() } ast := parser.Expr() - if errorListener.err != nil { - return errorListener.err + if errorListener.Error() != nil { + return errorListener.Error() } if parser.GetCurrentToken().GetTokenType() != antlr.TokenEOF { diff --git a/internal/parser/planparserv2/plan_parser_v2_test.go b/internal/parser/planparserv2/plan_parser_v2_test.go index a12320f37a..d0a20057ec 100644 --- a/internal/parser/planparserv2/plan_parser_v2_test.go +++ b/internal/parser/planparserv2/plan_parser_v2_test.go @@ -4,6 +4,7 @@ import ( "sync" "testing" + "github.com/antlr/antlr4/runtime/Go/antlr" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -604,6 +605,39 @@ func TestCreateSearchPlan_Invalid(t *testing.T) { }) } +var listenerCnt int + +type errorListenerTest struct { + antlr.DefaultErrorListener +} + +func (l *errorListenerTest) SyntaxError(recognizer antlr.Recognizer, offendingSymbol interface{}, line, column int, msg string, e antlr.RecognitionException) { + listenerCnt += 1 +} + +func (l *errorListenerTest) ReportAmbiguity(recognizer antlr.Parser, dfa *antlr.DFA, startIndex, stopIndex int, exact bool, ambigAlts *antlr.BitSet, configs antlr.ATNConfigSet) { + listenerCnt += 1 +} + +func (l *errorListenerTest) Error() error { + return nil +} + +func Test_FixErrorListenerNotRemoved(t *testing.T) { + schema := newTestSchema() + schemaHelper, err := typeutil.CreateSchemaHelper(schema) + assert.NoError(t, err) + + normal := "1 < Int32Field < (Int16Field)" + for i := 0; i < 10; i++ { + err := handleExprWithErrorListener(schemaHelper, normal, &errorListenerTest{}) + err1, ok := err.(error) + assert.True(t, ok) + assert.Error(t, err1) + } + assert.True(t, listenerCnt <= 10) +} + func Test_handleExpr(t *testing.T) { schema := newTestSchema() schemaHelper, err := typeutil.CreateSchemaHelper(schema) diff --git a/internal/parser/planparserv2/pool.go b/internal/parser/planparserv2/pool.go index 6ea87ca862..5792af5923 100644 --- a/internal/parser/planparserv2/pool.go +++ b/internal/parser/planparserv2/pool.go @@ -72,11 +72,13 @@ func getParser(lexer *antlrparser.PlanLexer, listeners ...antlr.ErrorListener) * func putLexer(lexer *antlrparser.PlanLexer) { lexer.SetInputStream(nil) + lexer.RemoveErrorListeners() lexerPool.ReturnObject(context.TODO(), lexer) } func putParser(parser *antlrparser.PlanParser) { parser.SetInputStream(nil) + parser.RemoveErrorListeners() parserPool.ReturnObject(context.TODO(), parser) } diff --git a/internal/parser/planparserv2/pool_test.go b/internal/parser/planparserv2/pool_test.go index 0e9de00918..2b39fd8684 100644 --- a/internal/parser/planparserv2/pool_test.go +++ b/internal/parser/planparserv2/pool_test.go @@ -16,10 +16,10 @@ func genNaiveInputStream() *antlr.InputStream { func Test_getLexer(t *testing.T) { var lexer *antlrparser.PlanLexer resetLexerPool() - lexer = getLexer(genNaiveInputStream(), &errorListener{}) + lexer = getLexer(genNaiveInputStream(), &errorListenerImpl{}) assert.NotNil(t, lexer) - lexer = getLexer(genNaiveInputStream(), &errorListener{}) + lexer = getLexer(genNaiveInputStream(), &errorListenerImpl{}) assert.NotNil(t, lexer) pool := getLexerPool() @@ -36,13 +36,13 @@ func Test_getParser(t *testing.T) { var parser *antlrparser.PlanParser resetParserPool() - lexer = getLexer(genNaiveInputStream(), &errorListener{}) + lexer = getLexer(genNaiveInputStream(), &errorListenerImpl{}) assert.NotNil(t, lexer) - parser = getParser(lexer, &errorListener{}) + parser = getParser(lexer, &errorListenerImpl{}) assert.NotNil(t, parser) - parser = getParser(lexer, &errorListener{}) + parser = getParser(lexer, &errorListenerImpl{}) assert.NotNil(t, parser) pool := getParserPool()