diff --git a/internal/parser/planparserv2/plan_parser_v2.go b/internal/parser/planparserv2/plan_parser_v2.go index 210ee57c8b..17cf00b555 100644 --- a/internal/parser/planparserv2/plan_parser_v2.go +++ b/internal/parser/planparserv2/plan_parser_v2.go @@ -17,13 +17,26 @@ func handleExpr(schema *typeutil.SchemaHelper, exprStr string) interface{} { inputStream := antlr.NewInputStream(exprStr) errorListener := &errorListener{} - parser := getParser(inputStream, errorListener) + + lexer := getLexer(inputStream, errorListener) + if errorListener.err != nil { + return errorListener.err + } + + parser := getParser(lexer, errorListener) + if errorListener.err != nil { + return errorListener.err + } ast := parser.Expr() if errorListener.err != nil { return errorListener.err } + // lexer & parser won't be used by this thread, can be put into pool. + putLexer(lexer) + putParser(parser) + visitor := NewParserVisitor(schema) return ast.Accept(visitor) } diff --git a/internal/parser/planparserv2/plan_parser_v2_test.go b/internal/parser/planparserv2/plan_parser_v2_test.go index 3d04591051..14f0cc1d61 100644 --- a/internal/parser/planparserv2/plan_parser_v2_test.go +++ b/internal/parser/planparserv2/plan_parser_v2_test.go @@ -1,6 +1,7 @@ package planparserv2 import ( + "sync" "testing" "github.com/milvus-io/milvus/internal/proto/planpb" @@ -488,3 +489,45 @@ func TestCreateSearchPlan_Invalid(t *testing.T) { assert.Error(t, err) }) } + +func Test_handleExpr(t *testing.T) { + schema := newTestSchema() + schemaHelper, err := typeutil.CreateSchemaHelper(schema) + assert.NoError(t, err) + + ret1 := handleExpr(schemaHelper, "this is not a normal expression") + err1, ok := ret1.(error) + assert.True(t, ok) + assert.Error(t, err1) +} + +// test if handleExpr is thread-safe. +func Test_handleExpr_17126(t *testing.T) { + schema := newTestSchema() + schemaHelper, err := typeutil.CreateSchemaHelper(schema) + assert.NoError(t, err) + normal := "Int64Field > 0" + abnormal := "this is not a normal expression" + + n := 4 // default parallel in regression. + m := 16 + var wg sync.WaitGroup + for i := 0; i < n*m; i++ { + wg.Add(1) + i := i + go func() { + defer wg.Done() + if i%2 == 0 { + ret := handleExpr(schemaHelper, normal) + _, ok := ret.(error) + assert.False(t, ok) + } else { + ret := handleExpr(schemaHelper, abnormal) + err, ok := ret.(error) + assert.True(t, ok) + assert.Error(t, err) + } + }() + } + wg.Wait() +} diff --git a/internal/parser/planparserv2/pool.go b/internal/parser/planparserv2/pool.go index 0a81106bef..005884b60e 100644 --- a/internal/parser/planparserv2/pool.go +++ b/internal/parser/planparserv2/pool.go @@ -25,26 +25,33 @@ func getLexer(stream *antlr.InputStream, listeners ...antlr.ErrorListener) *antl if !ok { lexer = antlrparser.NewPlanLexer(nil) } - lexer.SetInputStream(stream) for _, listener := range listeners { lexer.AddErrorListener(listener) } - lexerPool.Put(lexer) + lexer.SetInputStream(stream) return lexer } -func getParser(stream *antlr.InputStream, listeners ...antlr.ErrorListener) *antlrparser.PlanParser { - lexer := getLexer(stream, listeners...) +func getParser(lexer *antlrparser.PlanLexer, listeners ...antlr.ErrorListener) *antlrparser.PlanParser { tokenStream := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel) parser, ok := parserPool.Get().(*antlrparser.PlanParser) if !ok { parser = antlrparser.NewPlanParser(nil) } - parser.SetInputStream(tokenStream) - parser.BuildParseTrees = true for _, listener := range listeners { parser.AddErrorListener(listener) } - parserPool.Put(parser) + parser.BuildParseTrees = true + parser.SetInputStream(tokenStream) return parser } + +func putLexer(lexer *antlrparser.PlanLexer) { + lexer.SetInputStream(nil) + lexerPool.Put(lexer) +} + +func putParser(parser *antlrparser.PlanParser) { + parser.SetInputStream(nil) + parserPool.Put(parser) +} diff --git a/internal/parser/planparserv2/pool_test.go b/internal/parser/planparserv2/pool_test.go index 9e238d8bbc..8c0f5b929b 100644 --- a/internal/parser/planparserv2/pool_test.go +++ b/internal/parser/planparserv2/pool_test.go @@ -25,11 +25,15 @@ func Test_getLexer(t *testing.T) { } func Test_getParser(t *testing.T) { + var lexer *antlrparser.PlanLexer var parser *antlrparser.PlanParser - parser = getParser(genNaiveInputStream(), &errorListener{}) + lexer = getLexer(genNaiveInputStream(), &errorListener{}) + assert.NotNil(t, lexer) + + parser = getParser(lexer, &errorListener{}) assert.NotNil(t, parser) - parser = getParser(genNaiveInputStream(), &errorListener{}) + parser = getParser(lexer, &errorListener{}) assert.NotNil(t, parser) }