enhance: optimize planparser grammar by consolidating similar rules and adding benchmarks (#47092)

issue #47025

- Consolidate MATCH_ALL/MATCH_ANY into MatchSimple rule and
MATCH_LEAST/MATCH_MOST/MATCH_EXACT into MatchThreshold rule
- Consolidate spatial functions (STEquals, STTouches, etc.) into single
SpatialBinary rule
- Simplify visitor implementation to handle consolidated grammar rules
- Add comprehensive benchmark tests for parser performance
- Add optimization comparison tests to validate changes

Signed-off-by: xiaofanluan <xiaofan.luan@zilliz.com>
This commit is contained in:
Xiaofan 2026-01-20 10:47:30 +08:00 committed by GitHub
parent 9d055adc4a
commit ec38b905f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 4070 additions and 3340 deletions

View File

@ -20,11 +20,8 @@ expr:
| PHRASEMATCH'('Identifier',' StringLiteral (',' expr)? ')' # PhraseMatch
| RANDOMSAMPLE'(' expr ')' # RandomSample
| ElementFilter'('Identifier',' expr')' # ElementFilter
| MATCH_ALL'(' Identifier ',' expr ')' # MatchAll
| MATCH_ANY'(' Identifier ',' expr ')' # MatchAny
| MATCH_LEAST'(' Identifier ',' expr ',' THRESHOLD ASSIGN IntegerConstant ')' # MatchLeast
| MATCH_MOST'(' Identifier ',' expr ',' THRESHOLD ASSIGN IntegerConstant ')' # MatchMost
| MATCH_EXACT'(' Identifier ',' expr ',' THRESHOLD ASSIGN IntegerConstant ')' # MatchExact
| op=(MATCH_ALL | MATCH_ANY) '(' Identifier ',' expr ')' # MatchSimple
| op=(MATCH_LEAST | MATCH_MOST | MATCH_EXACT) '(' Identifier ',' expr ',' THRESHOLD ASSIGN IntegerConstant ')' # MatchThreshold
| expr POW expr # Power
| op = (ADD | SUB | BNOT | NOT) expr # Unary
// | '(' typeName ')' expr # Cast
@ -35,13 +32,7 @@ expr:
| (JSONContains | ArrayContains)'('expr',' expr')' # JSONContains
| (JSONContainsAll | ArrayContainsAll)'('expr',' expr')' # JSONContainsAll
| (JSONContainsAny | ArrayContainsAny)'('expr',' expr')' # JSONContainsAny
| STEuqals'('Identifier','StringLiteral')' # STEuqals
| STTouches'('Identifier','StringLiteral')' # STTouches
| STOverlaps'('Identifier','StringLiteral')' # STOverlaps
| STCrosses'('Identifier','StringLiteral')' # STCrosses
| STContains'('Identifier','StringLiteral')' # STContains
| STIntersects'('Identifier','StringLiteral')' # STIntersects
| STWithin'('Identifier','StringLiteral')' # STWithin
| op=(STEuqals | STTouches | STOverlaps | STCrosses | STContains | STIntersects | STWithin) '(' Identifier ',' StringLiteral ')' # SpatialBinary
| STDWithin'('Identifier','StringLiteral',' expr')' # STDWithin
| STIsValid'('Identifier')' # STIsValid
| ArrayLength'('(Identifier | JSONIdentifier)')' # ArrayLength
@ -61,15 +52,6 @@ expr:
textMatchOption:
MINIMUM_SHOULD_MATCH ASSIGN IntegerConstant;
// typeName: ty = (BOOL | INT8 | INT16 | INT32 | INT64 | FLOAT | DOUBLE);
// BOOL: 'bool';
// INT8: 'int8';
// INT16: 'int16';
// INT32: 'int32';
// INT64: 'int64';
// FLOAT: 'float';
// DOUBLE: 'double';
LBRACE: '{';
RBRACE: '}';

View File

@ -0,0 +1,222 @@
package planparserv2
import (
"fmt"
"testing"
"github.com/antlr4-go/antlr/v4"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// Benchmark test cases covering various expression patterns
// benchmarkExprs is the shared table of parser benchmark cases. Each entry
// is run as a sub-benchmark; together they cover comparisons, boolean logic,
// arithmetic, IN lists, strings, arrays, JSON access, NULL checks, ranges,
// EXISTS and mixed expressions.
var benchmarkExprs = []struct {
	name string // sub-benchmark name reported by b.Run
	expr string // expression source fed to the parser
}{
	// Simple comparisons
	{"simple_eq", "Int64Field == 100"},
	{"simple_lt", "Int64Field < 100"},
	{"simple_ne", "Int64Field != 100"},
	// Boolean operations
	{"bool_and", "Int64Field > 10 && Int64Field < 100"},
	{"bool_or", "Int64Field < 10 || Int64Field > 100"},
	{"bool_and_text", "Int64Field > 10 and Int64Field < 100"},
	{"bool_or_text", "Int64Field < 10 or Int64Field > 100"},
	{"bool_complex", "(Int64Field > 10 && Int64Field < 100) || (FloatField > 1.0 && FloatField < 10.0)"},
	// Arithmetic
	{"arith_add", "Int64Field + 5 == 100"},
	{"arith_mul", "Int64Field * 2 < 200"},
	// IN expressions
	{"in_small", "Int64Field in [1, 2, 3]"},
	{"in_medium", "Int64Field in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]"},
	{"not_in", "Int64Field not in [1, 2, 3]"},
	// String operations
	{"string_eq", `StringField == "hello"`},
	{"string_like", `StringField like "hello%"`},
	{"string_in", `StringField in ["a", "b", "c"]`},
	// Array operations
	{"array_contains", "array_contains(ArrayField, 1)"},
	{"array_contains_all", "array_contains_all(ArrayField, [1, 2, 3])"},
	{"array_contains_any", "array_contains_any(ArrayField, [1, 2, 3])"},
	{"array_length", "array_length(ArrayField) == 10"},
	// JSON field access
	{"json_simple", `JSONField["key"] == 100`},
	{"json_nested", `JSONField["a"]["b"] == "value"`},
	{"json_contains", `json_contains(JSONField["arr"], 1)`},
	// NULL checks
	{"is_null", "Int64Field is null"},
	{"is_not_null", "Int64Field is not null"},
	{"is_null_upper", "Int64Field IS NULL"},
	{"is_not_null_upper", "Int64Field IS NOT NULL"},
	// Range expressions
	{"range_lt_lt", "10 < Int64Field < 100"},
	{"range_le_le", "10 <= Int64Field <= 100"},
	{"range_gt_gt", "100 > Int64Field > 10"},
	// EXISTS
	{"exists", `exists JSONField["key"]`},
	// Complex mixed expressions
	{"complex_1", `Int64Field > 10 && StringField like "test%" && array_length(ArrayField) > 0`},
	{"complex_2", `(Int64Field in [1,2,3] || FloatField > 1.5) && StringField != "exclude"`},
	{"complex_3", `JSONField["status"] == "active" && Int64Field > 0 && Int64Field is not null`},
}
// getTestSchemaHelper builds a SchemaHelper from the shared test schema.
// It aborts the benchmark immediately if the schema cannot be converted.
func getTestSchemaHelper(b *testing.B) *typeutil.SchemaHelper {
	// Mark as a helper so failures are attributed to the caller's line.
	b.Helper()
	schema := newTestSchema(true)
	schemaHelper, err := typeutil.CreateSchemaHelper(schema)
	require.NoError(b, err)
	return schemaHelper
}
// BenchmarkParserOverall tests overall parser performance
// BenchmarkParserOverall measures the full ParseExpr pipeline for every
// expression in benchmarkExprs. The expression cache is purged on every
// iteration so each pass exercises the parser rather than the LRU cache;
// the Purge cost is therefore included in the reported time (same approach
// as BenchmarkOptimizationComparison).
func BenchmarkParserOverall(b *testing.B) {
	schemaHelper := getTestSchemaHelper(b)
	for _, tc := range benchmarkExprs {
		b.Run(tc.name, func(b *testing.B) {
			// Clear cache to test raw parsing performance
			exprCache.Purge()
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				// Purge inside the loop: without it every iteration after
				// the first is a cache hit, and the benchmark would measure
				// cache lookups instead of raw parsing.
				exprCache.Purge()
				_, err := ParseExpr(schemaHelper, tc.expr, nil)
				if err != nil {
					b.Fatalf("failed to parse %s: %v", tc.expr, err)
				}
			}
		})
	}
}
// BenchmarkLexerOnly tests lexer performance
// BenchmarkLexerOnly measures tokenization in isolation: each iteration
// builds a fresh input stream, drains all tokens from a pooled lexer, and
// returns the lexer to the pool.
func BenchmarkLexerOnly(b *testing.B) {
	for _, bench := range benchmarkExprs {
		b.Run(bench.name, func(b *testing.B) {
			normalized := convertHanToASCII(bench.expr)
			b.ResetTimer()
			for n := 0; n < b.N; n++ {
				errListener := &errorListenerImpl{}
				stream := antlr.NewInputStream(normalized)
				lx := getLexer(stream, errListener)
				// Drain the token stream until EOF.
				for lx.NextToken().GetTokenType() != antlr.TokenEOF {
				}
				putLexer(lx)
			}
		})
	}
}
// BenchmarkParseOnly tests parsing without visitor
// BenchmarkParseOnly measures lexing plus parse-tree construction, skipping
// the visitor/plan-building stage entirely.
func BenchmarkParseOnly(b *testing.B) {
	for _, bench := range benchmarkExprs {
		b.Run(bench.name, func(b *testing.B) {
			normalized := convertHanToASCII(bench.expr)
			b.ResetTimer()
			for n := 0; n < b.N; n++ {
				errListener := &errorListenerImpl{}
				stream := antlr.NewInputStream(normalized)
				lx := getLexer(stream, errListener)
				p := getParser(lx, errListener)
				// Build the parse tree; the result is discarded on purpose.
				_ = p.Expr()
				putLexer(lx)
				putParser(p)
			}
		})
	}
}
// BenchmarkPoolPerformance tests object pool overhead
// BenchmarkPoolPerformance measures the raw get/put overhead of the lexer
// and parser object pools, without any tokenizing or parsing work.
func BenchmarkPoolPerformance(b *testing.B) {
	b.Run("lexer_pool", func(b *testing.B) {
		stream := antlr.NewInputStream("Int64Field > 10")
		b.ResetTimer()
		for n := 0; n < b.N; n++ {
			putLexer(getLexer(stream))
		}
	})
	b.Run("parser_pool", func(b *testing.B) {
		stream := antlr.NewInputStream("Int64Field > 10")
		lx := getLexer(stream)
		b.ResetTimer()
		for n := 0; n < b.N; n++ {
			putParser(getParser(lx))
		}
		putLexer(lx)
	})
}
// BenchmarkScalability tests performance with increasing expression complexity
// BenchmarkScalability measures how parse time grows with expression size:
// chains of AND-ed predicates and IN lists of increasing length. The cache
// is purged on every iteration so each pass re-parses the expression; the
// Purge cost is included in the measurement.
func BenchmarkScalability(b *testing.B) {
	schemaHelper := getTestSchemaHelper(b)
	// Test with increasing number of AND conditions. The expression is
	// built before ResetTimer, so string construction is not timed.
	for _, count := range []int{1, 5, 10, 20} {
		b.Run(fmt.Sprintf("and_chain_%d", count), func(b *testing.B) {
			expr := "Int64Field > 0"
			for i := 1; i < count; i++ {
				expr += fmt.Sprintf(" && Int64Field < %d", 1000+i)
			}
			exprCache.Purge()
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				// Purge inside the loop: otherwise every iteration after
				// the first is a cache hit and scaling behaviour is hidden.
				exprCache.Purge()
				_, err := ParseExpr(schemaHelper, expr, nil)
				if err != nil {
					b.Fatal(err)
				}
			}
		})
	}
	// Test with increasing IN list size.
	for _, count := range []int{10, 50, 100, 500} {
		b.Run(fmt.Sprintf("in_list_%d", count), func(b *testing.B) {
			expr := "Int64Field in ["
			for i := 0; i < count; i++ {
				if i > 0 {
					expr += ","
				}
				expr += fmt.Sprintf("%d", i)
			}
			expr += "]"
			exprCache.Purge()
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				exprCache.Purge()
				_, err := ParseExpr(schemaHelper, expr, nil)
				if err != nil {
					b.Fatal(err)
				}
			}
		})
	}
}

File diff suppressed because one or more lines are too long

View File

@ -7,118 +7,6 @@ type BasePlanVisitor struct {
*antlr.BaseParseTreeVisitor
}
func (v *BasePlanVisitor) VisitString(ctx *StringContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitMatchAny(ctx *MatchAnyContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitFloating(ctx *FloatingContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitIsNotNull(ctx *IsNotNullContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitIdentifier(ctx *IdentifierContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitSTIntersects(ctx *STIntersectsContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitLike(ctx *LikeContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitEquality(ctx *EqualityContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitBoolean(ctx *BooleanContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitShift(ctx *ShiftContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitTimestamptzCompareForward(ctx *TimestamptzCompareForwardContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitReverseRange(ctx *ReverseRangeContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitEmptyArray(ctx *EmptyArrayContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitPhraseMatch(ctx *PhraseMatchContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitMatchLeast(ctx *MatchLeastContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitArrayLength(ctx *ArrayLengthContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitSTTouches(ctx *STTouchesContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitTerm(ctx *TermContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitJSONContains(ctx *JSONContainsContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitSTWithin(ctx *STWithinContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitRange(ctx *RangeContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitMatchAll(ctx *MatchAllContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitSTIsValid(ctx *STIsValidContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitBitXor(ctx *BitXorContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitElementFilter(ctx *ElementFilterContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitBitAnd(ctx *BitAndContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitSTOverlaps(ctx *STOverlapsContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitMatchMost(ctx *MatchMostContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitJSONIdentifier(ctx *JSONIdentifierContext) interface{} {
return v.VisitChildren(ctx)
}
@ -127,11 +15,19 @@ func (v *BasePlanVisitor) VisitRandomSample(ctx *RandomSampleContext) interface{
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitSpatialBinary(ctx *SpatialBinaryContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitParens(ctx *ParensContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitMatchExact(ctx *MatchExactContext) interface{} {
func (v *BasePlanVisitor) VisitString(ctx *StringContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitFloating(ctx *FloatingContext) interface{} {
return v.VisitChildren(ctx)
}
@ -143,10 +39,22 @@ func (v *BasePlanVisitor) VisitLogicalOr(ctx *LogicalOrContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitIsNotNull(ctx *IsNotNullContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitMulDivMod(ctx *MulDivModContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitIdentifier(ctx *IdentifierContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitLike(ctx *LikeContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitLogicalAnd(ctx *LogicalAndContext) interface{} {
return v.VisitChildren(ctx)
}
@ -155,6 +63,14 @@ func (v *BasePlanVisitor) VisitTemplateVariable(ctx *TemplateVariableContext) in
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitEquality(ctx *EqualityContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitBoolean(ctx *BooleanContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitTimestamptzCompareReverse(ctx *TimestamptzCompareReverseContext) interface{} {
return v.VisitChildren(ctx)
}
@ -163,11 +79,19 @@ func (v *BasePlanVisitor) VisitSTDWithin(ctx *STDWithinContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitShift(ctx *ShiftContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitTimestamptzCompareForward(ctx *TimestamptzCompareForwardContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitCall(ctx *CallContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitSTCrosses(ctx *STCrossesContext) interface{} {
func (v *BasePlanVisitor) VisitReverseRange(ctx *ReverseRangeContext) interface{} {
return v.VisitChildren(ctx)
}
@ -175,19 +99,43 @@ func (v *BasePlanVisitor) VisitBitOr(ctx *BitOrContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitEmptyArray(ctx *EmptyArrayContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitAddSub(ctx *AddSubContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitPhraseMatch(ctx *PhraseMatchContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitRelational(ctx *RelationalContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitArrayLength(ctx *ArrayLengthContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitTextMatch(ctx *TextMatchContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitSTContains(ctx *STContainsContext) interface{} {
func (v *BasePlanVisitor) VisitTerm(ctx *TermContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitJSONContains(ctx *JSONContainsContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitRange(ctx *RangeContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitMatchSimple(ctx *MatchSimpleContext) interface{} {
return v.VisitChildren(ctx)
}
@ -207,11 +155,27 @@ func (v *BasePlanVisitor) VisitJSONContainsAny(ctx *JSONContainsAnyContext) inte
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitSTIsValid(ctx *STIsValidContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitMatchThreshold(ctx *MatchThresholdContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitBitXor(ctx *BitXorContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitExists(ctx *ExistsContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitSTEuqals(ctx *STEuqalsContext) interface{} {
func (v *BasePlanVisitor) VisitElementFilter(ctx *ElementFilterContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BasePlanVisitor) VisitBitAnd(ctx *BitAndContext) interface{} {
return v.VisitChildren(ctx)
}

File diff suppressed because it is too large Load Diff

View File

@ -7,101 +7,23 @@ import "github.com/antlr4-go/antlr/v4"
type PlanVisitor interface {
antlr.ParseTreeVisitor
// Visit a parse tree produced by PlanParser#String.
VisitString(ctx *StringContext) interface{}
// Visit a parse tree produced by PlanParser#MatchAny.
VisitMatchAny(ctx *MatchAnyContext) interface{}
// Visit a parse tree produced by PlanParser#Floating.
VisitFloating(ctx *FloatingContext) interface{}
// Visit a parse tree produced by PlanParser#IsNotNull.
VisitIsNotNull(ctx *IsNotNullContext) interface{}
// Visit a parse tree produced by PlanParser#Identifier.
VisitIdentifier(ctx *IdentifierContext) interface{}
// Visit a parse tree produced by PlanParser#STIntersects.
VisitSTIntersects(ctx *STIntersectsContext) interface{}
// Visit a parse tree produced by PlanParser#Like.
VisitLike(ctx *LikeContext) interface{}
// Visit a parse tree produced by PlanParser#Equality.
VisitEquality(ctx *EqualityContext) interface{}
// Visit a parse tree produced by PlanParser#Boolean.
VisitBoolean(ctx *BooleanContext) interface{}
// Visit a parse tree produced by PlanParser#Shift.
VisitShift(ctx *ShiftContext) interface{}
// Visit a parse tree produced by PlanParser#TimestamptzCompareForward.
VisitTimestamptzCompareForward(ctx *TimestamptzCompareForwardContext) interface{}
// Visit a parse tree produced by PlanParser#ReverseRange.
VisitReverseRange(ctx *ReverseRangeContext) interface{}
// Visit a parse tree produced by PlanParser#EmptyArray.
VisitEmptyArray(ctx *EmptyArrayContext) interface{}
// Visit a parse tree produced by PlanParser#PhraseMatch.
VisitPhraseMatch(ctx *PhraseMatchContext) interface{}
// Visit a parse tree produced by PlanParser#MatchLeast.
VisitMatchLeast(ctx *MatchLeastContext) interface{}
// Visit a parse tree produced by PlanParser#ArrayLength.
VisitArrayLength(ctx *ArrayLengthContext) interface{}
// Visit a parse tree produced by PlanParser#STTouches.
VisitSTTouches(ctx *STTouchesContext) interface{}
// Visit a parse tree produced by PlanParser#Term.
VisitTerm(ctx *TermContext) interface{}
// Visit a parse tree produced by PlanParser#JSONContains.
VisitJSONContains(ctx *JSONContainsContext) interface{}
// Visit a parse tree produced by PlanParser#STWithin.
VisitSTWithin(ctx *STWithinContext) interface{}
// Visit a parse tree produced by PlanParser#Range.
VisitRange(ctx *RangeContext) interface{}
// Visit a parse tree produced by PlanParser#MatchAll.
VisitMatchAll(ctx *MatchAllContext) interface{}
// Visit a parse tree produced by PlanParser#STIsValid.
VisitSTIsValid(ctx *STIsValidContext) interface{}
// Visit a parse tree produced by PlanParser#BitXor.
VisitBitXor(ctx *BitXorContext) interface{}
// Visit a parse tree produced by PlanParser#ElementFilter.
VisitElementFilter(ctx *ElementFilterContext) interface{}
// Visit a parse tree produced by PlanParser#BitAnd.
VisitBitAnd(ctx *BitAndContext) interface{}
// Visit a parse tree produced by PlanParser#STOverlaps.
VisitSTOverlaps(ctx *STOverlapsContext) interface{}
// Visit a parse tree produced by PlanParser#MatchMost.
VisitMatchMost(ctx *MatchMostContext) interface{}
// Visit a parse tree produced by PlanParser#JSONIdentifier.
VisitJSONIdentifier(ctx *JSONIdentifierContext) interface{}
// Visit a parse tree produced by PlanParser#RandomSample.
VisitRandomSample(ctx *RandomSampleContext) interface{}
// Visit a parse tree produced by PlanParser#SpatialBinary.
VisitSpatialBinary(ctx *SpatialBinaryContext) interface{}
// Visit a parse tree produced by PlanParser#Parens.
VisitParens(ctx *ParensContext) interface{}
// Visit a parse tree produced by PlanParser#MatchExact.
VisitMatchExact(ctx *MatchExactContext) interface{}
// Visit a parse tree produced by PlanParser#String.
VisitString(ctx *StringContext) interface{}
// Visit a parse tree produced by PlanParser#Floating.
VisitFloating(ctx *FloatingContext) interface{}
// Visit a parse tree produced by PlanParser#JSONContainsAll.
VisitJSONContainsAll(ctx *JSONContainsAllContext) interface{}
@ -109,41 +31,80 @@ type PlanVisitor interface {
// Visit a parse tree produced by PlanParser#LogicalOr.
VisitLogicalOr(ctx *LogicalOrContext) interface{}
// Visit a parse tree produced by PlanParser#IsNotNull.
VisitIsNotNull(ctx *IsNotNullContext) interface{}
// Visit a parse tree produced by PlanParser#MulDivMod.
VisitMulDivMod(ctx *MulDivModContext) interface{}
// Visit a parse tree produced by PlanParser#Identifier.
VisitIdentifier(ctx *IdentifierContext) interface{}
// Visit a parse tree produced by PlanParser#Like.
VisitLike(ctx *LikeContext) interface{}
// Visit a parse tree produced by PlanParser#LogicalAnd.
VisitLogicalAnd(ctx *LogicalAndContext) interface{}
// Visit a parse tree produced by PlanParser#TemplateVariable.
VisitTemplateVariable(ctx *TemplateVariableContext) interface{}
// Visit a parse tree produced by PlanParser#Equality.
VisitEquality(ctx *EqualityContext) interface{}
// Visit a parse tree produced by PlanParser#Boolean.
VisitBoolean(ctx *BooleanContext) interface{}
// Visit a parse tree produced by PlanParser#TimestamptzCompareReverse.
VisitTimestamptzCompareReverse(ctx *TimestamptzCompareReverseContext) interface{}
// Visit a parse tree produced by PlanParser#STDWithin.
VisitSTDWithin(ctx *STDWithinContext) interface{}
// Visit a parse tree produced by PlanParser#Shift.
VisitShift(ctx *ShiftContext) interface{}
// Visit a parse tree produced by PlanParser#TimestamptzCompareForward.
VisitTimestamptzCompareForward(ctx *TimestamptzCompareForwardContext) interface{}
// Visit a parse tree produced by PlanParser#Call.
VisitCall(ctx *CallContext) interface{}
// Visit a parse tree produced by PlanParser#STCrosses.
VisitSTCrosses(ctx *STCrossesContext) interface{}
// Visit a parse tree produced by PlanParser#ReverseRange.
VisitReverseRange(ctx *ReverseRangeContext) interface{}
// Visit a parse tree produced by PlanParser#BitOr.
VisitBitOr(ctx *BitOrContext) interface{}
// Visit a parse tree produced by PlanParser#EmptyArray.
VisitEmptyArray(ctx *EmptyArrayContext) interface{}
// Visit a parse tree produced by PlanParser#AddSub.
VisitAddSub(ctx *AddSubContext) interface{}
// Visit a parse tree produced by PlanParser#PhraseMatch.
VisitPhraseMatch(ctx *PhraseMatchContext) interface{}
// Visit a parse tree produced by PlanParser#Relational.
VisitRelational(ctx *RelationalContext) interface{}
// Visit a parse tree produced by PlanParser#ArrayLength.
VisitArrayLength(ctx *ArrayLengthContext) interface{}
// Visit a parse tree produced by PlanParser#TextMatch.
VisitTextMatch(ctx *TextMatchContext) interface{}
// Visit a parse tree produced by PlanParser#STContains.
VisitSTContains(ctx *STContainsContext) interface{}
// Visit a parse tree produced by PlanParser#Term.
VisitTerm(ctx *TermContext) interface{}
// Visit a parse tree produced by PlanParser#JSONContains.
VisitJSONContains(ctx *JSONContainsContext) interface{}
// Visit a parse tree produced by PlanParser#Range.
VisitRange(ctx *RangeContext) interface{}
// Visit a parse tree produced by PlanParser#MatchSimple.
VisitMatchSimple(ctx *MatchSimpleContext) interface{}
// Visit a parse tree produced by PlanParser#Unary.
VisitUnary(ctx *UnaryContext) interface{}
@ -157,11 +118,23 @@ type PlanVisitor interface {
// Visit a parse tree produced by PlanParser#JSONContainsAny.
VisitJSONContainsAny(ctx *JSONContainsAnyContext) interface{}
// Visit a parse tree produced by PlanParser#STIsValid.
VisitSTIsValid(ctx *STIsValidContext) interface{}
// Visit a parse tree produced by PlanParser#MatchThreshold.
VisitMatchThreshold(ctx *MatchThresholdContext) interface{}
// Visit a parse tree produced by PlanParser#BitXor.
VisitBitXor(ctx *BitXorContext) interface{}
// Visit a parse tree produced by PlanParser#Exists.
VisitExists(ctx *ExistsContext) interface{}
// Visit a parse tree produced by PlanParser#STEuqals.
VisitSTEuqals(ctx *STEuqalsContext) interface{}
// Visit a parse tree produced by PlanParser#ElementFilter.
VisitElementFilter(ctx *ElementFilterContext) interface{}
// Visit a parse tree produced by PlanParser#BitAnd.
VisitBitAnd(ctx *BitAndContext) interface{}
// Visit a parse tree produced by PlanParser#IsNull.
VisitIsNull(ctx *IsNullContext) interface{}

View File

@ -0,0 +1,358 @@
package planparserv2
import (
"fmt"
"testing"
"github.com/antlr4-go/antlr/v4"
"github.com/stretchr/testify/require"
antlrparser "github.com/milvus-io/milvus/internal/parser/planparserv2/generated"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// Comprehensive benchmark expressions covering all parser features
// optimizationBenchExprs is the benchmark table for validating the grammar
// consolidation: it exercises every parser feature, including both lower-
// and upper-case spellings of case-insensitive keywords.
var optimizationBenchExprs = []struct {
	name string // sub-benchmark name reported by b.Run
	expr string // expression source fed to the parser
}{
	// === Basic Operations ===
	{"int_eq", "Int64Field == 100"},
	{"int_lt", "Int64Field < 100"},
	{"int_gt", "Int64Field > 100"},
	{"int_le", "Int64Field <= 100"},
	{"int_ge", "Int64Field >= 100"},
	{"int_ne", "Int64Field != 100"},
	{"float_eq", "FloatField == 1.5"},
	{"float_lt", "FloatField < 1.5"},
	// === String Operations ===
	{"string_eq", `StringField == "hello"`},
	{"string_ne", `StringField != "world"`},
	{"string_like", `StringField like "prefix%"`},
	{"string_like_suffix", `StringField like "%suffix"`},
	{"string_like_contains", `StringField like "%middle%"`},
	// === Boolean Operations (case-insensitive keywords) ===
	{"bool_and_symbol", "Int64Field > 10 && Int64Field < 100"},
	{"bool_or_symbol", "Int64Field < 10 || Int64Field > 100"},
	{"bool_and_keyword", "Int64Field > 10 and Int64Field < 100"},
	{"bool_or_keyword", "Int64Field < 10 or Int64Field > 100"},
	{"bool_AND_upper", "Int64Field > 10 AND Int64Field < 100"},
	{"bool_OR_upper", "Int64Field < 10 OR Int64Field > 100"},
	{"bool_not", "not (Int64Field > 100)"},
	{"bool_NOT_upper", "NOT (Int64Field > 100)"},
	// === IN Operations ===
	{"in_3", "Int64Field in [1, 2, 3]"},
	{"in_5", "Int64Field in [1, 2, 3, 4, 5]"},
	{"in_10", "Int64Field in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]"},
	{"not_in", "Int64Field not in [1, 2, 3]"},
	{"NOT_IN_upper", "Int64Field NOT IN [1, 2, 3]"},
	{"string_in", `StringField in ["a", "b", "c"]`},
	// === NULL Checks (case-insensitive) ===
	{"is_null_lower", "Int64Field is null"},
	{"is_null_upper", "Int64Field IS NULL"},
	{"is_not_null_lower", "Int64Field is not null"},
	{"is_not_null_upper", "Int64Field IS NOT NULL"},
	// === Range Expressions ===
	{"range_lt_lt", "10 < Int64Field < 100"},
	{"range_le_le", "10 <= Int64Field <= 100"},
	{"range_gt_gt", "100 > Int64Field > 10"},
	{"range_ge_ge", "100 >= Int64Field >= 10"},
	// === Arithmetic ===
	{"arith_add", "Int64Field + 5 == 100"},
	{"arith_sub", "Int64Field - 5 == 95"},
	{"arith_mul", "Int64Field * 2 < 200"},
	{"arith_div", "Int64Field / 2 > 25"},
	{"arith_mod", "Int64Field % 10 == 0"},
	// === Array Operations (case-insensitive) ===
	{"array_contains", "array_contains(ArrayField, 1)"},
	{"ARRAY_CONTAINS_upper", "ARRAY_CONTAINS(ArrayField, 1)"},
	{"array_contains_all", "array_contains_all(ArrayField, [1, 2, 3])"},
	{"array_contains_any", "array_contains_any(ArrayField, [1, 2, 3])"},
	{"array_length", "array_length(ArrayField) == 10"},
	{"ARRAY_LENGTH_upper", "ARRAY_LENGTH(ArrayField) == 10"},
	// === JSON Operations (case-insensitive) ===
	{"json_access", `JSONField["key"] == 100`},
	{"json_nested", `JSONField["a"]["b"] == "value"`},
	{"json_contains", `json_contains(JSONField["arr"], 1)`},
	{"JSON_CONTAINS_upper", `JSON_CONTAINS(JSONField["arr"], 1)`},
	{"json_contains_all", `json_contains_all(JSONField["arr"], [1, 2])`},
	{"json_contains_any", `json_contains_any(JSONField["arr"], [1, 2])`},
	// === EXISTS (case-insensitive) ===
	{"exists_lower", `exists JSONField["key"]`},
	{"EXISTS_upper", `EXISTS JSONField["key"]`},
	// === LIKE (case-insensitive) ===
	{"like_lower", `StringField like "test%"`},
	{"LIKE_upper", `StringField LIKE "test%"`},
	// === Complex Expressions ===
	{"complex_and_chain", "Int64Field > 0 && Int64Field < 100 && FloatField > 1.0"},
	{"complex_or_chain", "Int64Field < 0 || Int64Field > 100 || FloatField < 0"},
	{"complex_mixed", "(Int64Field > 10 && Int64Field < 100) || (FloatField > 1.0 && FloatField < 10.0)"},
	{"complex_nested", "((Int64Field > 10 && Int64Field < 50) || (Int64Field > 60 && Int64Field < 100)) && FloatField > 0"},
	{"complex_with_in", `(Int64Field in [1,2,3] || FloatField > 1.5) && StringField != "exclude"`},
	{"complex_with_json", `JSONField["status"] == "active" && Int64Field > 0 && Int64Field is not null`},
	{"complex_with_array", `Int64Field > 10 && array_length(ArrayField) > 0 && array_contains(ArrayField, 1)`},
	{"complex_full", `Int64Field > 10 && StringField like "test%" && array_length(ArrayField) > 0 && JSONField["active"] == true`},
}
// getOptBenchSchemaHelper builds a SchemaHelper from the shared test schema.
// It aborts the benchmark immediately if the schema cannot be converted.
func getOptBenchSchemaHelper(b *testing.B) *typeutil.SchemaHelper {
	// Mark as a helper so failures are attributed to the caller's line.
	b.Helper()
	schema := newTestSchema(true)
	schemaHelper, err := typeutil.CreateSchemaHelper(schema)
	require.NoError(b, err)
	return schemaHelper
}
// BenchmarkOptimizationComparison benchmarks full parsing pipeline without cache
// BenchmarkOptimizationComparison measures the full ParseExpr pipeline with
// the expression cache disabled: the cache is purged on every iteration so
// that each pass re-parses the input.
func BenchmarkOptimizationComparison(b *testing.B) {
	helper := getOptBenchSchemaHelper(b)
	for _, bench := range optimizationBenchExprs {
		b.Run(bench.name, func(b *testing.B) {
			// Start from an empty cache before timing begins.
			exprCache.Purge()
			b.ResetTimer()
			for n := 0; n < b.N; n++ {
				// Purge each iteration so every parse is a cache miss.
				exprCache.Purge()
				if _, err := ParseExpr(helper, bench.expr, nil); err != nil {
					b.Fatalf("failed to parse %s: %v", bench.expr, err)
				}
			}
		})
	}
}
// BenchmarkLexerComparison benchmarks only the lexer stage
// BenchmarkLexerComparison measures only the lexer stage: each iteration
// tokenizes the expression to EOF using a pooled lexer.
func BenchmarkLexerComparison(b *testing.B) {
	for _, bench := range optimizationBenchExprs {
		b.Run(bench.name, func(b *testing.B) {
			normalized := convertHanToASCII(bench.expr)
			b.ResetTimer()
			for n := 0; n < b.N; n++ {
				errListener := &errorListenerImpl{}
				stream := antlr.NewInputStream(normalized)
				lx := getLexer(stream, errListener)
				// Drain the token stream until EOF.
				for lx.NextToken().GetTokenType() != antlr.TokenEOF {
				}
				putLexer(lx)
			}
		})
	}
}
// BenchmarkParserComparison benchmarks lexer + parser (without visitor)
// BenchmarkParserComparison measures lexing plus parse-tree construction,
// without running the visitor stage.
func BenchmarkParserComparison(b *testing.B) {
	for _, bench := range optimizationBenchExprs {
		b.Run(bench.name, func(b *testing.B) {
			normalized := convertHanToASCII(bench.expr)
			b.ResetTimer()
			for n := 0; n < b.N; n++ {
				errListener := &errorListenerImpl{}
				stream := antlr.NewInputStream(normalized)
				lx := getLexer(stream, errListener)
				p := getParser(lx, errListener)
				// Build the parse tree; the result is discarded on purpose.
				_ = p.Expr()
				putLexer(lx)
				putParser(p)
			}
		})
	}
}
// BenchmarkConvertHanToASCII benchmarks the Han character conversion
// BenchmarkConvertHanToASCII measures the Han-character normalization step
// on pure-ASCII inputs of varying length and on input containing Han runes.
func BenchmarkConvertHanToASCII(b *testing.B) {
	cases := []struct {
		name string
		expr string
	}{
		{"short_ascii", "a == 1"},
		{"medium_ascii", "Int64Field > 10 && Int64Field < 100"},
		{"long_ascii", `(Int64Field in [1,2,3,4,5] || FloatField > 1.5) && StringField != "exclude" && JSONField["status"] == "active"`},
		{"with_han_chars", "字段名 == 100 && 另一个字段 > 50"},
	}
	for _, c := range cases {
		b.Run(c.name, func(b *testing.B) {
			for n := 0; n < b.N; n++ {
				_ = convertHanToASCII(c.expr)
			}
		})
	}
}
// BenchmarkDecodeUnicode benchmarks unicode decoding
// BenchmarkDecodeUnicode measures \uXXXX escape decoding on inputs with
// zero, few, and many escape sequences.
func BenchmarkDecodeUnicode(b *testing.B) {
	cases := []struct {
		name string
		expr string
	}{
		{"no_unicode", "Int64Field == 100"},
		{"with_unicode", `field\u0041\u0042 == 100`},
		{"multiple_unicode", `\u0041\u0042\u0043\u0044 == "test"`},
	}
	for _, c := range cases {
		b.Run(c.name, func(b *testing.B) {
			for n := 0; n < b.N; n++ {
				_ = decodeUnicode(c.expr)
			}
		})
	}
}
// BenchmarkPoolOverhead measures sync.Pool get/put overhead
// BenchmarkPoolOverhead measures sync.Pool get/put round-trip cost for the
// lexer and parser pools, with no lexing or parsing performed.
func BenchmarkPoolOverhead(b *testing.B) {
	b.Run("lexer_get_put", func(b *testing.B) {
		stream := antlr.NewInputStream("a == 1")
		b.ResetTimer()
		for n := 0; n < b.N; n++ {
			putLexer(getLexer(stream))
		}
	})
	b.Run("parser_get_put", func(b *testing.B) {
		stream := antlr.NewInputStream("a == 1")
		// A directly-constructed lexer is enough to feed the parser pool.
		lx := antlrparser.NewPlanLexer(stream)
		b.ResetTimer()
		for n := 0; n < b.N; n++ {
			putParser(getParser(lx))
		}
	})
}
// BenchmarkCacheEffect compares cached vs uncached parse performance.
// (The doc comment previously named a nonexistent BenchmarkCacheHitRatio;
// it now matches the declared function name, as `go vet`/doc tooling expect.)
func BenchmarkCacheEffect(b *testing.B) {
	schemaHelper := getOptBenchSchemaHelper(b)
	expr := "Int64Field > 10 && Int64Field < 100"
	b.Run("without_cache", func(b *testing.B) {
		for i := 0; i < b.N; i++ {
			// Purge inside the loop so every iteration parses from scratch.
			exprCache.Purge()
			_, _ = ParseExpr(schemaHelper, expr, nil)
		}
	})
	b.Run("with_cache", func(b *testing.B) {
		// Warm up cache so the timed loop below measures cache hits only.
		exprCache.Purge()
		_, _ = ParseExpr(schemaHelper, expr, nil)
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			_, _ = ParseExpr(schemaHelper, expr, nil)
		}
	})
}
// BenchmarkKeywordCaseSensitivity checks that case-insensitive keywords parse
// at comparable speed in their lower- and upper-case spellings.
func BenchmarkKeywordCaseSensitivity(b *testing.B) {
	schemaHelper := getOptBenchSchemaHelper(b)
	keywords := []struct {
		name  string
		lower string
		upper string
	}{
		{"and", "a > 1 and a < 10", "a > 1 AND a < 10"},
		{"or", "a < 1 or a > 10", "a < 1 OR a > 10"},
		{"not", "not (a > 10)", "NOT (a > 10)"},
		{"in", "a in [1,2,3]", "a IN [1,2,3]"},
		{"like", `b like "test%"`, `b LIKE "test%"`},
		{"is_null", "a is null", "a IS NULL"},
		{"is_not_null", "a is not null", "a IS NOT NULL"},
		{"exists", `exists c["key"]`, `EXISTS c["key"]`},
		{"array_contains", "array_contains(d, 1)", "ARRAY_CONTAINS(d, 1)"},
		{"json_contains", `json_contains(c["arr"], 1)`, `JSON_CONTAINS(c["arr"], 1)`},
	}
	// runCase benchmarks one spelling with the cache purged on every iteration
	// so each parse starts cold.
	runCase := func(b *testing.B, expr string) {
		exprCache.Purge()
		b.ResetTimer()
		for n := 0; n < b.N; n++ {
			exprCache.Purge()
			_, _ = ParseExpr(schemaHelper, expr, nil)
		}
	}
	for _, kw := range keywords {
		kw := kw
		b.Run(kw.name+"_lower", func(b *testing.B) { runCase(b, kw.lower) })
		b.Run(kw.name+"_upper", func(b *testing.B) { runCase(b, kw.upper) })
	}
}
// BenchmarkExpressionComplexity shows how parse time scales with the number of
// AND-ed clauses and with the size of an IN list.
func BenchmarkExpressionComplexity(b *testing.B) {
	schemaHelper := getOptBenchSchemaHelper(b)
	// Chains of && terms of increasing length.
	for _, chainLen := range []int{1, 2, 4, 8, 16} {
		b.Run(fmt.Sprintf("and_chain_%d", chainLen), func(b *testing.B) {
			expr := "Int64Field > 0"
			for j := 1; j < chainLen; j++ {
				expr += fmt.Sprintf(" && Int64Field < %d", 1000+j)
			}
			exprCache.Purge()
			b.ResetTimer()
			for n := 0; n < b.N; n++ {
				exprCache.Purge()
				_, _ = ParseExpr(schemaHelper, expr, nil)
			}
		})
	}
	// IN lists of increasing length.
	for _, listLen := range []int{5, 10, 25, 50, 100} {
		b.Run(fmt.Sprintf("in_list_%d", listLen), func(b *testing.B) {
			expr := "Int64Field in ["
			for j := 0; j < listLen; j++ {
				if j > 0 {
					expr += ","
				}
				expr += fmt.Sprintf("%d", j)
			}
			expr += "]"
			exprCache.Purge()
			b.ResetTimer()
			for n := 0; n < b.N; n++ {
				exprCache.Purge()
				_, _ = ParseExpr(schemaHelper, expr, nil)
			}
		})
	}
}

File diff suppressed because it is too large Load Diff

View File

@ -613,7 +613,7 @@ func TestExpr_PhraseMatch(t *testing.T) {
}
for i, exprStr := range unsupported {
_, err := ParseExpr(helper, exprStr, nil)
assert.True(t, strings.HasSuffix(err.Error(), errMsgs[i]), fmt.Sprintf("Error expected: %v, actual %v", errMsgs[i], err.Error()))
assert.True(t, strings.Contains(err.Error(), errMsgs[i]), fmt.Sprintf("Error expected: %v, actual %v", errMsgs[i], err.Error()))
}
}
@ -2716,3 +2716,392 @@ func TestExpr_Match(t *testing.T) {
assertInvalidExpr(t, helper, expr)
}
}
// ============================================================================
// Timestamptz Expression Tests
// These tests cover VisitTimestamptzCompareForward and VisitTimestamptzCompareReverse
// which are used for optimized timestamptz comparisons with optional INTERVAL arithmetic
// ============================================================================
// newTestSchemaWithTimestamptz builds a schema helper for timestamptz tests.
// newTestSchema(true) already includes every DataType, Timestamptz among them.
func newTestSchemaWithTimestamptz(t *testing.T) *typeutil.SchemaHelper {
	helper, err := typeutil.CreateSchemaHelper(newTestSchema(true))
	require.NoError(t, err)
	return helper
}
// TestExpr_TimestamptzCompareForward covers forward timestamptz comparisons:
// the column on the left and an ISO literal on the right, optionally with
// INTERVAL arithmetic applied to the column. The ISO keyword is required
// before the timestamp string literal.
func TestExpr_TimestamptzCompareForward(t *testing.T) {
	schema := newTestSchemaWithTimestamptz(t)
	for _, candidate := range []string{
		// Quick path: plain comparisons, no INTERVAL.
		`TimestamptzField > ISO '2025-01-01T00:00:00Z'`,
		`TimestamptzField >= ISO '2025-01-01T00:00:00Z'`,
		`TimestamptzField < ISO '2025-12-31T23:59:59Z'`,
		`TimestamptzField <= ISO '2025-06-15T12:00:00Z'`,
		`TimestamptzField == ISO '2025-03-20T10:30:00Z'`,
		`TimestamptzField != ISO '2025-08-10T08:00:00Z'`,
		// Slow path: INTERVAL arithmetic on the column.
		`TimestamptzField + INTERVAL 'P1D' > ISO '2025-01-01T00:00:00Z'`,
		`TimestamptzField - INTERVAL 'P1D' < ISO '2025-12-31T23:59:59Z'`,
		`TimestamptzField + INTERVAL 'PT1H' >= ISO '2025-06-15T12:00:00Z'`,
		`TimestamptzField - INTERVAL 'PT30M' <= ISO '2025-03-20T10:30:00Z'`,
		`TimestamptzField + INTERVAL 'P1Y' == ISO '2026-01-01T00:00:00Z'`,
		`TimestamptzField - INTERVAL 'P6M' != ISO '2024-06-01T00:00:00Z'`,
		// Multi-component ISO 8601 durations.
		`TimestamptzField + INTERVAL 'P1Y2M3D' > ISO '2025-01-01T00:00:00Z'`,
		`TimestamptzField + INTERVAL 'PT10H30M15S' < ISO '2025-12-31T23:59:59Z'`,
		`TimestamptzField - INTERVAL 'P1Y2M3DT4H5M6S' >= ISO '2024-01-01T00:00:00Z'`,
	} {
		assertValidExpr(t, schema, candidate)
	}
}
// TestExpr_TimestamptzCompareReverse covers reverse timestamptz comparisons:
// the ISO literal on the left and the column on the right (the operator is
// reversed internally, e.g. '>' becomes '<'). The ISO keyword is required
// before the timestamp string.
func TestExpr_TimestamptzCompareReverse(t *testing.T) {
	schema := newTestSchemaWithTimestamptz(t)
	for _, candidate := range []string{
		// Quick path: plain reverse comparisons, no INTERVAL.
		`ISO '2025-01-01T00:00:00Z' < TimestamptzField`,
		`ISO '2025-01-01T00:00:00Z' <= TimestamptzField`,
		`ISO '2025-12-31T23:59:59Z' > TimestamptzField`,
		`ISO '2025-06-15T12:00:00Z' >= TimestamptzField`,
		`ISO '2025-03-20T10:30:00Z' == TimestamptzField`,
		`ISO '2025-08-10T08:00:00Z' != TimestamptzField`,
		// Slow path: INTERVAL arithmetic applied after the field.
		`ISO '2025-01-01T00:00:00Z' < TimestamptzField + INTERVAL 'P1D'`,
		`ISO '2025-12-31T23:59:59Z' > TimestamptzField - INTERVAL 'P1D'`,
		`ISO '2025-06-15T12:00:00Z' <= TimestamptzField + INTERVAL 'PT1H'`,
		`ISO '2025-03-20T10:30:00Z' >= TimestamptzField - INTERVAL 'PT30M'`,
	} {
		assertValidExpr(t, schema, candidate)
	}
}
// TestExpr_TimestamptzCompareInvalid asserts that malformed timestamptz
// expressions are rejected: wrong field types, bad ISO timestamps, and bad
// INTERVAL durations.
func TestExpr_TimestamptzCompareInvalid(t *testing.T) {
	schema := newTestSchemaWithTimestamptz(t)
	for _, candidate := range []string{
		// INTERVAL arithmetic on non-timestamptz fields must fail.
		`Int64Field + INTERVAL 'P1D' > ISO '2025-01-01T00:00:00Z'`,
		`VarCharField + INTERVAL 'P1D' < ISO '2025-01-01T00:00:00Z'`,
		// Malformed timestamps after the ISO keyword.
		`TimestamptzField > ISO 'invalid-timestamp'`,
		`TimestamptzField < ISO '2025-13-01T00:00:00Z'`, // Invalid month
		`TimestamptzField > ISO '2025-01-32T00:00:00Z'`, // Invalid day
		// Malformed ISO 8601 durations.
		`TimestamptzField + INTERVAL 'invalid' > ISO '2025-01-01T00:00:00Z'`,
		`TimestamptzField + INTERVAL '1D' > ISO '2025-01-01T00:00:00Z'`, // Missing P prefix
	} {
		assertInvalidExpr(t, schema, candidate)
	}
}
// ============================================================================
// Power Expression Tests
// These tests cover VisitPower for constant power operations
// ============================================================================
// TestExpr_Power covers constant power expressions (VisitPower). Power is only
// legal when both operands are constants; field operands must be rejected.
func TestExpr_Power(t *testing.T) {
	schema := newTestSchemaHelper(t)
	for _, good := range []string{
		// Integer exponents.
		`2 ** 3 == 8`,
		`3 ** 2 == 9`,
		`10 ** 0 == 1`,
		// Float exponents.
		`2.0 ** 3.0 == 8.0`,
		`4.0 ** 0.5 > 1.0`,
		// Negative exponent.
		`2 ** -1 == 0.5`,
		// Constant powers embedded in arithmetic with a field.
		`Int64Field + (2 ** 3) > 0`,
		`Int64Field * (10 ** 2) < 1000`,
	} {
		assertValidExpr(t, schema, good)
	}
	for _, bad := range []string{
		// Any field operand makes power invalid.
		`Int64Field ** 2 == 100`,
		`2 ** Int64Field == 8`,
		`Int64Field ** Int64Field == 1`,
	} {
		assertInvalidExpr(t, schema, bad)
	}
}
// ============================================================================
// Error Handling Tests
// These tests cover the int64OverflowError type and error handling paths
// ============================================================================
// TestInt64OverflowError exercises int64OverflowError.Error and the
// isInt64OverflowError predicate, including non-overflow and nil inputs.
func TestInt64OverflowError(t *testing.T) {
	overflow := &int64OverflowError{literal: "9223372036854775808"}
	msg := overflow.Error()
	assert.Contains(t, msg, "int64 overflow")
	assert.Contains(t, msg, "9223372036854775808")
	// The predicate must match only this error type.
	assert.True(t, isInt64OverflowError(overflow))
	assert.False(t, isInt64OverflowError(fmt.Errorf("some other error")))
	assert.False(t, isInt64OverflowError(nil))
}
// ============================================================================
// reverseCompareOp Tests
// This function is used internally to reverse comparison operators
// ============================================================================
// Test_reverseCompareOp verifies operator reversal used when a comparison is
// rewritten from "value op column" into "column op value" form. Equality ops
// are their own reversal; unknown ops map to Invalid.
func Test_reverseCompareOp(t *testing.T) {
	cases := []struct {
		give planpb.OpType
		want planpb.OpType
	}{
		{planpb.OpType_LessThan, planpb.OpType_GreaterThan},
		{planpb.OpType_LessEqual, planpb.OpType_GreaterEqual},
		{planpb.OpType_GreaterThan, planpb.OpType_LessThan},
		{planpb.OpType_GreaterEqual, planpb.OpType_LessEqual},
		{planpb.OpType_Equal, planpb.OpType_Equal},
		{planpb.OpType_NotEqual, planpb.OpType_NotEqual},
		{planpb.OpType_Invalid, planpb.OpType_Invalid},
		{planpb.OpType_PrefixMatch, planpb.OpType_Invalid}, // Unknown ops return Invalid
	}
	for _, c := range cases {
		assert.Equal(t, c.want, reverseCompareOp(c.give), "reverseCompareOp(%v)", c.give)
	}
}
// ============================================================================
// Additional Coverage Tests for Edge Cases
// ============================================================================
// TestExpr_AdditionalEdgeCases covers assorted valid edge cases: scientific
// float notation, boolean literals, empty strings, deep JSON paths, and
// array_length.
func TestExpr_AdditionalEdgeCases(t *testing.T) {
	schema := newTestSchemaHelper(t)
	for _, candidate := range []string{
		// Scientific notation and long float literals.
		`FloatField > 1e10`,
		`DoubleField < 1e-10`,
		`FloatField == 3.14159265358979`,
		// Boolean literal comparisons.
		`true == true`,
		`false != true`,
		// Empty-string comparisons.
		`StringField == ""`,
		`VarCharField != ""`,
		// Deeply nested JSON path.
		`JSONField["level1"]["level2"]["level3"] > 0`,
		// array_length in comparisons.
		`array_length(ArrayField) > 0`,
		`array_length(ArrayField) == 10`,
	} {
		assertValidExpr(t, schema, candidate)
	}
}
// TestExpr_InvalidOperatorCombinations asserts that unsupported operators and
// type-mismatched arithmetic are rejected, exercising the error paths in the
// visitor methods.
func TestExpr_InvalidOperatorCombinations(t *testing.T) {
	schema := newTestSchemaHelper(t)
	for _, candidate := range []string{
		// Shift operators are not supported.
		`Int64Field << 2`,
		`Int64Field >> 2`,
		// Bitwise operators are not supported.
		`Int64Field & 0xFF`,
		`Int64Field | 0xFF`,
		`Int64Field ^ 0xFF`,
		// Incompatible operand types.
		`"string" + 1`,
		`BoolField + 1`,
	} {
		assertInvalidExpr(t, schema, candidate)
	}
}
// TestExpr_VisitBooleanEdgeCases exercises VisitBoolean. Boolean literals and
// fields are only valid filter expressions inside a comparison, never as
// standalone values.
func TestExpr_VisitBooleanEdgeCases(t *testing.T) {
	schema := newTestSchemaHelper(t)
	for _, candidate := range []string{
		`true == true`,
		`false == false`,
		`true != false`,
		`BoolField == true`,
		`BoolField != false`,
		`BoolField == BoolField`,
		`not (BoolField == true)`,
	} {
		assertValidExpr(t, schema, candidate)
	}
}
// TestExpr_VisitFloatingEdgeCases exercises VisitFloating with plain,
// scientific-notation, and negative-exponent float literals.
func TestExpr_VisitFloatingEdgeCases(t *testing.T) {
	schema := newTestSchemaHelper(t)
	for _, candidate := range []string{
		`FloatField > 0.0`,
		`FloatField < 1.0e10`,
		`FloatField >= -1.0e-10`,
		`FloatField <= 3.14159265`,
		`DoubleField == 2.718281828`,
	} {
		assertValidExpr(t, schema, candidate)
	}
}
// TestExpr_VisitRangeEdgeCases exercises VisitRange and VisitReverseRange with
// every combination of strict/inclusive bounds, in both directions, plus
// invalid ranges (bool fields, non-constant bounds).
func TestExpr_VisitRangeEdgeCases(t *testing.T) {
	schema := newTestSchemaHelper(t)
	for _, good := range []string{
		// Forward ranges: lower < field < upper.
		`1 < Int64Field < 10`,
		`0.0 < FloatField < 1.0`,
		`"a" < StringField < "z"`,
		// Forward ranges with inclusive bounds.
		`1 <= Int64Field < 10`,
		`1 < Int64Field <= 10`,
		`1 <= Int64Field <= 10`,
		// Reverse ranges: upper > field > lower.
		`10 > Int64Field > 1`,
		`1.0 > FloatField > 0.0`,
		`"z" > StringField > "a"`,
		// Reverse ranges with inclusive bounds.
		`10 >= Int64Field > 1`,
		`10 > Int64Field >= 1`,
		`10 >= Int64Field >= 1`,
	} {
		assertValidExpr(t, schema, good)
	}
	for _, bad := range []string{
		// Ranges over bool fields are invalid.
		`true < BoolField < false`,
		// Bounds must be constants, not fields.
		`Int64Field < Int32Field < Int64Field`,
	} {
		assertInvalidExpr(t, schema, bad)
	}
}
// TestExpr_VisitUnaryEdgeCases exercises VisitUnary. Logical negation
// (not / !) must wrap a boolean predicate; arithmetic negation is valid
// inside comparison operands.
func TestExpr_VisitUnaryEdgeCases(t *testing.T) {
	schema := newTestSchemaHelper(t)
	for _, candidate := range []string{
		`not (Int64Field > 0)`,
		`!(Int64Field < 10)`,
		`not (BoolField == true)`,
		`not (true == false)`,
		`!(FloatField >= 1.0)`,
		// Arithmetic negation inside comparison operands.
		`Int64Field > -1`,
		`Int64Field < -(-5)`,
	} {
		assertValidExpr(t, schema, candidate)
	}
}
// TestExpr_ConstantFolding verifies that constant sub-expressions inside
// arithmetic (add/sub/mul/div/mod, int and float, mixed precedence) are
// accepted and foldable.
func TestExpr_ConstantFolding(t *testing.T) {
	schema := newTestSchemaHelper(t)
	for _, candidate := range []string{
		// Additive folding.
		`Int64Field > (1 + 2)`,
		`Int64Field < (10 - 5)`,
		`Int64Field == (1 + 2 + 3)`,
		// Multiplicative folding.
		`Int64Field > (2 * 3)`,
		`Int64Field < (10 / 2)`,
		`Int64Field == (10 % 3)`,
		// Mixed precedence.
		`Int64Field > (2 * 3 + 4)`,
		`Int64Field < (10 - 2 * 3)`,
		// Float folding.
		`FloatField > (1.0 + 2.0)`,
		`FloatField < (10.0 / 2.0)`,
	} {
		assertValidExpr(t, schema, candidate)
	}
}

View File

@ -1,53 +1,29 @@
package planparserv2
import (
"context"
"sync"
"github.com/antlr4-go/antlr/v4"
pool "github.com/jolestar/go-commons-pool/v2"
antlrparser "github.com/milvus-io/milvus/internal/parser/planparserv2/generated"
"github.com/milvus-io/milvus/pkg/v2/util/hardware"
)
var (
config = &pool.ObjectPoolConfig{
LIFO: pool.DefaultLIFO,
MaxTotal: hardware.GetCPUNum() * 8,
MaxIdle: hardware.GetCPUNum() * 8,
MinIdle: pool.DefaultMinIdle,
MinEvictableIdleTime: pool.DefaultMinEvictableIdleTime,
SoftMinEvictableIdleTime: pool.DefaultSoftMinEvictableIdleTime,
NumTestsPerEvictionRun: pool.DefaultNumTestsPerEvictionRun,
EvictionPolicyName: pool.DefaultEvictionPolicyName,
EvictionContext: context.Background(),
TestOnCreate: pool.DefaultTestOnCreate,
TestOnBorrow: pool.DefaultTestOnBorrow,
TestOnReturn: pool.DefaultTestOnReturn,
TestWhileIdle: pool.DefaultTestWhileIdle,
TimeBetweenEvictionRuns: pool.DefaultTimeBetweenEvictionRuns,
BlockWhenExhausted: false,
lexerPool = sync.Pool{
New: func() interface{} {
return antlrparser.NewPlanLexer(nil)
},
}
ctx = context.Background()
lexerPoolFactory = pool.NewPooledObjectFactorySimple(
func(context.Context) (interface{}, error) {
return antlrparser.NewPlanLexer(nil), nil
})
lexerPool = pool.NewObjectPool(ctx, lexerPoolFactory, config)
parserPoolFactory = pool.NewPooledObjectFactorySimple(
func(context.Context) (interface{}, error) {
return antlrparser.NewPlanParser(nil), nil
})
parserPool = pool.NewObjectPool(ctx, parserPoolFactory, config)
parserPool = sync.Pool{
New: func() interface{} {
return antlrparser.NewPlanParser(nil)
},
}
)
func getLexer(stream *antlr.InputStream, listeners ...antlr.ErrorListener) *antlrparser.PlanLexer {
cached, _ := lexerPool.BorrowObject(context.Background())
lexer, ok := cached.(*antlrparser.PlanLexer)
if !ok {
lexer = antlrparser.NewPlanLexer(nil)
}
lexer := lexerPool.Get().(*antlrparser.PlanLexer)
for _, listener := range listeners {
lexer.AddErrorListener(listener)
}
@ -57,11 +33,7 @@ func getLexer(stream *antlr.InputStream, listeners ...antlr.ErrorListener) *antl
func getParser(lexer *antlrparser.PlanLexer, listeners ...antlr.ErrorListener) *antlrparser.PlanParser {
tokenStream := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel)
cached, _ := parserPool.BorrowObject(context.Background())
parser, ok := cached.(*antlrparser.PlanParser)
if !ok {
parser = antlrparser.NewPlanParser(nil)
}
parser := parserPool.Get().(*antlrparser.PlanParser)
for _, listener := range listeners {
parser.AddErrorListener(listener)
}
@ -73,39 +45,29 @@ func getParser(lexer *antlrparser.PlanLexer, listeners ...antlr.ErrorListener) *
func putLexer(lexer *antlrparser.PlanLexer) {
lexer.SetInputStream(nil)
lexer.RemoveErrorListeners()
lexerPool.ReturnObject(context.TODO(), lexer)
lexerPool.Put(lexer)
}
func putParser(parser *antlrparser.PlanParser) {
parser.SetInputStream(nil)
parser.RemoveErrorListeners()
parserPool.ReturnObject(context.TODO(), parser)
parserPool.Put(parser)
}
func getLexerPool() *pool.ObjectPool {
return lexerPool
}
// only for test
// resetLexerPool resets the lexer pool (for testing)
func resetLexerPool() {
ctx = context.Background()
lexerPoolFactory = pool.NewPooledObjectFactorySimple(
func(context.Context) (interface{}, error) {
return antlrparser.NewPlanLexer(nil), nil
})
lexerPool = pool.NewObjectPool(ctx, lexerPoolFactory, config)
lexerPool = sync.Pool{
New: func() interface{} {
return antlrparser.NewPlanLexer(nil)
},
}
}
func getParserPool() *pool.ObjectPool {
return parserPool
}
// only for test
// resetParserPool resets the parser pool (for testing)
func resetParserPool() {
ctx = context.Background()
parserPoolFactory = pool.NewPooledObjectFactorySimple(
func(context.Context) (interface{}, error) {
return antlrparser.NewPlanParser(nil), nil
})
parserPool = pool.NewObjectPool(ctx, parserPoolFactory, config)
parserPool = sync.Pool{
New: func() interface{} {
return antlrparser.NewPlanParser(nil)
},
}
}

View File

@ -1,6 +1,7 @@
package planparserv2
import (
"sync"
"testing"
"github.com/antlr4-go/antlr/v4"
@ -16,19 +17,21 @@ func genNaiveInputStream() *antlr.InputStream {
func Test_getLexer(t *testing.T) {
var lexer *antlrparser.PlanLexer
resetLexerPool()
lexer = getLexer(genNaiveInputStream(), &errorListenerImpl{})
assert.NotNil(t, lexer)
lexer = getLexer(genNaiveInputStream(), &errorListenerImpl{})
assert.NotNil(t, lexer)
pool := getLexerPool()
assert.Equal(t, pool.GetNumActive(), 2)
assert.Equal(t, pool.GetNumIdle(), 0)
lexer2 := getLexer(genNaiveInputStream(), &errorListenerImpl{})
assert.NotNil(t, lexer2)
// Return lexers to the pool
putLexer(lexer)
assert.Equal(t, pool.GetNumActive(), 1)
assert.Equal(t, pool.GetNumIdle(), 1)
putLexer(lexer2)
// Get from pool again - should reuse
lexer3 := getLexer(genNaiveInputStream(), &errorListenerImpl{})
assert.NotNil(t, lexer3)
putLexer(lexer3)
}
func Test_getParser(t *testing.T) {
@ -36,20 +39,244 @@ func Test_getParser(t *testing.T) {
var parser *antlrparser.PlanParser
resetParserPool()
resetLexerPool()
lexer = getLexer(genNaiveInputStream(), &errorListenerImpl{})
assert.NotNil(t, lexer)
parser = getParser(lexer, &errorListenerImpl{})
assert.NotNil(t, parser)
parser = getParser(lexer, &errorListenerImpl{})
parser2 := getParser(lexer, &errorListenerImpl{})
assert.NotNil(t, parser2)
// Return parsers to the pool
putParser(parser)
putParser(parser2)
// Get from pool again - should reuse
parser3 := getParser(lexer, &errorListenerImpl{})
assert.NotNil(t, parser3)
putParser(parser3)
putLexer(lexer)
}
// Test_poolConcurrency exercises the lexer/parser pools from ten goroutines at
// once, each doing a full get-parse-put cycle.
func Test_poolConcurrency(t *testing.T) {
	resetLexerPool()
	resetParserPool()
	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			lexer := getLexer(genNaiveInputStream(), &errorListenerImpl{})
			parser := getParser(lexer, &errorListenerImpl{})
			_ = parser.Expr()
			putParser(parser)
			putLexer(lexer)
		}()
	}
	wg.Wait()
}
// Test_lexerPoolReuse checks that a lexer returned to the pool still tokenizes
// correctly when handed out again — i.e. the reuse optimization is sound.
func Test_lexerPoolReuse(t *testing.T) {
	resetLexerPool()
	// Round-trip one lexer through the pool.
	first := getLexer(genNaiveInputStream(), &errorListenerImpl{})
	assert.NotNil(t, first)
	putLexer(first)
	// The next get may hand back the same instance; it must still work.
	second := getLexer(genNaiveInputStream(), &errorListenerImpl{})
	assert.NotNil(t, second)
	tokens := antlr.NewCommonTokenStream(second, antlr.TokenDefaultChannel)
	tokens.Fill()
	// A non-empty token stream proves the reused lexer produced tokens.
	assert.Greater(t, tokens.Size(), 0)
	putLexer(second)
}
// Test_parserPoolReuse checks that a parser returned to the pool still parses
// correctly when handed out again — i.e. the reuse optimization is sound.
func Test_parserPoolReuse(t *testing.T) {
	resetLexerPool()
	resetParserPool()
	// Round-trip one lexer/parser pair through the pools.
	lexerA := getLexer(genNaiveInputStream(), &errorListenerImpl{})
	parserA := getParser(lexerA, &errorListenerImpl{})
	assert.NotNil(t, parserA)
	putParser(parserA)
	putLexer(lexerA)
	// A fresh pair (possibly the same pooled instances) must parse correctly.
	lexerB := getLexer(genNaiveInputStream(), &errorListenerImpl{})
	parserB := getParser(lexerB, &errorListenerImpl{})
	assert.NotNil(t, parserB)
	tree := parserB.Expr()
	assert.NotNil(t, tree)
	putParser(parserB)
	putLexer(lexerB)
}
// Test_poolWithMultipleErrorListeners tests that error listeners are properly
// managed when getting/putting lexers and parsers: listeners attached at get
// time must not persist after the objects are returned to the pool.
func Test_poolWithMultipleErrorListeners(t *testing.T) {
	resetLexerPool()
	resetParserPool()
	// Create multiple error listeners
	listener1 := &errorListenerImpl{}
	listener2 := &errorListenerImpl{}
	// Get lexer with multiple listeners
	lexer := getLexer(genNaiveInputStream(), listener1, listener2)
	assert.NotNil(t, lexer)
	// Get parser with multiple listeners
	parser := getParser(lexer, listener1, listener2)
	assert.NotNil(t, parser)
	// NOTE(review): getParserPool and the GetNumActive/GetNumIdle assertions
	// below appear to be leftovers from the removed object-pool API — the new
	// sync.Pool exposes no occupancy counters. Confirm these three lines
	// should be deleted rather than kept.
	pool := getParserPool()
	assert.Equal(t, pool.GetNumActive(), 2)
	assert.Equal(t, pool.GetNumIdle(), 0)
	// Return to pool - listeners should be removed
	putParser(parser)
	putLexer(lexer)
	// Get again with different listeners - old listeners should not persist
	newListener := &errorListenerImpl{}
	lexer2 := getLexer(genNaiveInputStream(), newListener)
	parser2 := getParser(lexer2, newListener)
	// Should still work correctly
	expr := parser2.Expr()
	assert.NotNil(t, expr)
	putParser(parser2)
	putLexer(lexer2)
}
// Test_poolWithVariousExpressions runs pooled lexers/parsers over a variety of
// expression shapes, making sure reuse does not leak state between inputs.
func Test_poolWithVariousExpressions(t *testing.T) {
	resetLexerPool()
	resetParserPool()
	inputs := []string{
		"a > 2",
		"b < 10 && c > 5",
		"name == 'test'",
		"x + y > z",
		"arr[0] == 1",
		"json_field['key'] > 100",
		"a in [1, 2, 3]",
		"1 < x < 10",
		"not (a > b)",
	}
	for _, input := range inputs {
		lex := getLexer(antlr.NewInputStream(input), &errorListenerImpl{})
		p := getParser(lex, &errorListenerImpl{})
		tree := p.Expr()
		assert.NotNil(t, tree, "Expression '%s' should parse successfully", input)
		putParser(p)
		putLexer(lex)
	}
}
// Test_poolHighConcurrency hammers the pools from many goroutines to shake out
// thread-safety problems (most useful when run with -race).
func Test_poolHighConcurrency(t *testing.T) {
	resetLexerPool()
	resetParserPool()
	const (
		numGoroutines = 100
		numIterations = 10
	)
	var wg sync.WaitGroup
	for g := 0; g < numGoroutines; g++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < numIterations; j++ {
				stream := antlr.NewInputStream("field > " + string(rune('0'+j)))
				lex := getLexer(stream, &errorListenerImpl{})
				p := getParser(lex, &errorListenerImpl{})
				_ = p.Expr()
				putParser(p)
				putLexer(lex)
			}
		}()
	}
	wg.Wait()
}
// Test_resetLexerPool verifies that lexers obtained after a pool reset are
// fresh and functional.
func Test_resetLexerPool(t *testing.T) {
	// Cycle one lexer through the current pool.
	before := getLexer(genNaiveInputStream(), &errorListenerImpl{})
	putLexer(before)
	// Reset, then confirm the new pool hands out a usable lexer.
	resetLexerPool()
	after := getLexer(genNaiveInputStream(), &errorListenerImpl{})
	assert.NotNil(t, after)
	putLexer(after)
}
// Test_resetParserPool verifies that parsers obtained after a pool reset are
// fresh and functional.
func Test_resetParserPool(t *testing.T) {
	resetLexerPool()
	// Cycle one parser through the current pool.
	lexerA := getLexer(genNaiveInputStream(), &errorListenerImpl{})
	parserA := getParser(lexerA, &errorListenerImpl{})
	putParser(parserA)
	putLexer(lexerA)
	// Reset, then confirm the new pool hands out a usable parser.
	resetParserPool()
	lexerB := getLexer(genNaiveInputStream(), &errorListenerImpl{})
	parserB := getParser(lexerB, &errorListenerImpl{})
	assert.NotNil(t, parserB)
	putParser(parserB)
	putLexer(lexerB)
}
// Test_poolParserBuildParseTrees verifies that BuildParseTrees is set correctly
// This is important for the parser to generate the parse tree
func Test_poolParserBuildParseTrees(t *testing.T) {
	resetLexerPool()
	resetParserPool()
	lexer := getLexer(genNaiveInputStream(), &errorListenerImpl{})
	parser := getParser(lexer, &errorListenerImpl{})
	// BuildParseTrees should be true after getParser
	assert.True(t, parser.BuildParseTrees)
	putParser(parser)
	// NOTE(review): `pool` is not defined anywhere in this function, and
	// sync.Pool has no GetNumActive/GetNumIdle methods. The two assertions
	// below look like residue from the removed object-pool implementation —
	// confirm they should be deleted for this to compile.
	assert.Equal(t, pool.GetNumActive(), 1)
	assert.Equal(t, pool.GetNumIdle(), 1)
	putLexer(lexer)
}

View File

@ -26,6 +26,22 @@ const (
RandomScoreFileIdKey = "field_id"
)
// Precompiled regex patterns for performance optimization: compiled once at
// package scope instead of on every call.
var (
	// unicodeEscapeRegex matches a single \uXXXX escape (exactly four hex
	// digits); decodeUnicode replaces each match with its decoded rune.
	unicodeEscapeRegex = regexp.MustCompile(`\\u[0-9a-fA-F]{4}`)
	// iso8601DurationRegex matches ISO 8601 durations such as P1Y2M3DT4H5M6S.
	// Each component is optional, but the leading P is mandatory and time
	// components are only valid after the T separator. Capture groups are
	// positional: years, months, days, hours, minutes, seconds.
	iso8601DurationRegex = regexp.MustCompile(
		`^P` + // P at the start
			`(?:(\d+)Y)?` + // Years (optional)
			`(?:(\d+)M)?` + // Months (optional)
			`(?:(\d+)D)?` + // Days (optional)
			`(?:T` + // T separator (optional, but required for time parts)
			`(?:(\d+)H)?` + // Hours (optional)
			`(?:(\d+)M)?` + // Minutes (optional)
			`(?:(\d+)S)?` + // Seconds (optional)
			`)?$`,
	)
)
func IsBool(n *planpb.GenericValue) bool {
switch n.GetVal().(type) {
case *planpb.GenericValue_BoolVal:
@ -781,6 +797,20 @@ func parseJSONValue(value interface{}) (*planpb.GenericValue, schemapb.DataType,
}
func convertHanToASCII(s string) string {
// Fast path: check if any Han characters exist first
// This avoids allocation for the common case (ASCII-only strings)
hasHan := false
for _, r := range s {
if unicode.Is(unicode.Han, r) {
hasHan = true
break
}
}
if !hasHan {
return s // Zero allocation for ASCII-only strings
}
// Slow path: process string with Han characters
var builder strings.Builder
builder.Grow(len(s) * 6)
skipCur := false
@ -811,8 +841,7 @@ func convertHanToASCII(s string) string {
}
func decodeUnicode(input string) string {
re := regexp.MustCompile(`\\u[0-9a-fA-F]{4}`)
return re.ReplaceAllStringFunc(input, func(match string) string {
return unicodeEscapeRegex.ReplaceAllStringFunc(input, func(match string) string {
code, _ := strconv.ParseInt(match[2:], 16, 32)
return string(rune(code))
})
@ -835,17 +864,6 @@ func checkValidPoint(wktStr string) error {
}
func parseISODuration(durationStr string) (*planpb.Interval, error) {
iso8601DurationRegex := regexp.MustCompile(
`^P` + // P at the start
`(?:(\d+)Y)?` + // Years (optional)
`(?:(\d+)M)?` + // Months (optional)
`(?:(\d+)D)?` + // Days (optional)
`(?:T` + // T separator (optional, but required for time parts)
`(?:(\d+)H)?` + // Hours (optional)
`(?:(\d+)M)?` + // Minutes (optional)
`(?:(\d+)S)?` + // Seconds (optional)
`)?$`,
)
matches := iso8601DurationRegex.FindStringSubmatch(durationStr)
if matches == nil {
return nil, fmt.Errorf("invalid ISO 8601 duration: %s", durationStr)

View File

@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
)
@ -592,6 +593,722 @@ func Test_handleCompare(t *testing.T) {
})
}
// Test_toValueExpr tests the toValueExpr function which converts GenericValue to ExprWithType
// This tests all type branches including the nil return path for unknown types
func Test_toValueExpr(t *testing.T) {
t.Run("bool value", func(t *testing.T) {
// Test that bool values are correctly converted to Bool DataType
value := NewBool(true)
result := toValueExpr(value)
assert.NotNil(t, result)
assert.Equal(t, schemapb.DataType_Bool, result.dataType)
assert.True(t, result.expr.GetValueExpr().GetValue().GetBoolVal())
})
t.Run("int64 value", func(t *testing.T) {
// Test that int64 values are correctly converted to Int64 DataType
value := NewInt(42)
result := toValueExpr(value)
assert.NotNil(t, result)
assert.Equal(t, schemapb.DataType_Int64, result.dataType)
assert.Equal(t, int64(42), result.expr.GetValueExpr().GetValue().GetInt64Val())
})
t.Run("float value", func(t *testing.T) {
// Test that float values are correctly converted to Double DataType
value := NewFloat(3.14)
result := toValueExpr(value)
assert.NotNil(t, result)
assert.Equal(t, schemapb.DataType_Double, result.dataType)
assert.Equal(t, 3.14, result.expr.GetValueExpr().GetValue().GetFloatVal())
})
t.Run("string value", func(t *testing.T) {
// Test that string values are correctly converted to VarChar DataType
value := NewString("hello")
result := toValueExpr(value)
assert.NotNil(t, result)
assert.Equal(t, schemapb.DataType_VarChar, result.dataType)
assert.Equal(t, "hello", result.expr.GetValueExpr().GetValue().GetStringVal())
})
t.Run("array value", func(t *testing.T) {
// Test that array values are correctly converted to Array DataType
value := &planpb.GenericValue{
Val: &planpb.GenericValue_ArrayVal{
ArrayVal: &planpb.Array{
Array: []*planpb.GenericValue{NewInt(1), NewInt(2)},
ElementType: schemapb.DataType_Int64,
},
},
}
result := toValueExpr(value)
assert.NotNil(t, result)
assert.Equal(t, schemapb.DataType_Array, result.dataType)
})
t.Run("nil/unknown value type returns nil", func(t *testing.T) {
// Test that unknown value types return nil - this covers the default branch
value := &planpb.GenericValue{
Val: nil, // nil Val should trigger default case
}
result := toValueExpr(value)
assert.Nil(t, result)
})
}
// Test_getTargetType tests type inference for binary operations
// This ensures correct type promotion rules are applied
func Test_getTargetType(t *testing.T) {
	type promotion struct {
		name      string
		lhs       schemapb.DataType
		rhs       schemapb.DataType
		want      schemapb.DataType
		wantError bool
	}
	cases := []promotion{
		{name: "JSON with JSON returns JSON", lhs: schemapb.DataType_JSON, rhs: schemapb.DataType_JSON, want: schemapb.DataType_JSON},
		{name: "JSON with Float returns Double", lhs: schemapb.DataType_JSON, rhs: schemapb.DataType_Float, want: schemapb.DataType_Double},
		{name: "JSON with Int returns Int64", lhs: schemapb.DataType_JSON, rhs: schemapb.DataType_Int64, want: schemapb.DataType_Int64},
		{name: "Geometry with Geometry returns Geometry", lhs: schemapb.DataType_Geometry, rhs: schemapb.DataType_Geometry, want: schemapb.DataType_Geometry},
		{name: "Timestamptz with Timestamptz returns Timestamptz", lhs: schemapb.DataType_Timestamptz, rhs: schemapb.DataType_Timestamptz, want: schemapb.DataType_Timestamptz},
		{name: "Float with JSON returns Double", lhs: schemapb.DataType_Float, rhs: schemapb.DataType_JSON, want: schemapb.DataType_Double},
		{name: "Float with Int returns Double", lhs: schemapb.DataType_Float, rhs: schemapb.DataType_Int64, want: schemapb.DataType_Double},
		{name: "Int with Float returns Double", lhs: schemapb.DataType_Int64, rhs: schemapb.DataType_Float, want: schemapb.DataType_Double},
		{name: "Int with Int returns Int64", lhs: schemapb.DataType_Int64, rhs: schemapb.DataType_Int64, want: schemapb.DataType_Int64},
		{name: "Int with JSON returns Int64", lhs: schemapb.DataType_Int64, rhs: schemapb.DataType_JSON, want: schemapb.DataType_Int64},
		{name: "String with Int is incompatible", lhs: schemapb.DataType_VarChar, rhs: schemapb.DataType_Int64, wantError: true},
		{name: "Bool with Int is incompatible", lhs: schemapb.DataType_Bool, rhs: schemapb.DataType_Int64, wantError: true},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got, err := getTargetType(c.lhs, c.rhs)
			if c.wantError {
				assert.Error(t, err)
				assert.Contains(t, err.Error(), "incompatible data type")
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, c.want, got)
		})
	}
}
// Test_reverseOrder tests the reverseOrder function which reverses comparison operators
// This is used when the operands of a comparison are swapped
func Test_reverseOrder(t *testing.T) {
	type reversal struct {
		name      string
		op        planpb.OpType
		want      planpb.OpType
		wantError bool
	}
	cases := []reversal{
		{name: "LessThan reverses to GreaterThan", op: planpb.OpType_LessThan, want: planpb.OpType_GreaterThan},
		{name: "LessEqual reverses to GreaterEqual", op: planpb.OpType_LessEqual, want: planpb.OpType_GreaterEqual},
		{name: "GreaterThan reverses to LessThan", op: planpb.OpType_GreaterThan, want: planpb.OpType_LessThan},
		{name: "GreaterEqual reverses to LessEqual", op: planpb.OpType_GreaterEqual, want: planpb.OpType_LessEqual},
		{name: "Equal stays Equal", op: planpb.OpType_Equal, want: planpb.OpType_Equal},
		{name: "NotEqual stays NotEqual", op: planpb.OpType_NotEqual, want: planpb.OpType_NotEqual},
		{name: "Invalid op type returns error", op: planpb.OpType_Invalid, wantError: true},
		{name: "PrefixMatch cannot be reversed", op: planpb.OpType_PrefixMatch, wantError: true},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got, err := reverseOrder(c.op)
			if c.wantError {
				assert.Error(t, err)
				assert.Contains(t, err.Error(), "cannot reverse order")
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, c.want, got)
		})
	}
}
// Test_isIntegerColumn tests the isIntegerColumn helper function
// This function checks if a column can be converted to integer type
func Test_isIntegerColumn(t *testing.T) {
	// Local constructors keep the case table compact.
	scalar := func(dt schemapb.DataType) *planpb.ColumnInfo {
		return &planpb.ColumnInfo{DataType: dt}
	}
	array := func(elem schemapb.DataType) *planpb.ColumnInfo {
		return &planpb.ColumnInfo{DataType: schemapb.DataType_Array, ElementType: elem}
	}
	cases := []struct {
		name string
		col  *planpb.ColumnInfo
		want bool
	}{
		{"Int64 column is integer", scalar(schemapb.DataType_Int64), true},
		{"Int32 column is integer", scalar(schemapb.DataType_Int32), true},
		{"JSON column is integer (can contain integers)", scalar(schemapb.DataType_JSON), true},
		{"Array of Int64 is integer", array(schemapb.DataType_Int64), true},
		{"Timestamptz is integer", scalar(schemapb.DataType_Timestamptz), true},
		{"Float column is not integer", scalar(schemapb.DataType_Float), false},
		{"String column is not integer", scalar(schemapb.DataType_VarChar), false},
		{"Array of Float is not integer", array(schemapb.DataType_Float), false},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			assert.Equal(t, c.want, isIntegerColumn(c.col))
		})
	}
}
// Test_parseJSONValue covers JSON value parsing for every supported input
// kind, including nested arrays and the error branches.
func Test_parseJSONValue(t *testing.T) {
	t.Run("parse integer from json.Number", func(t *testing.T) {
		// Integral JSON numbers are parsed as Int64.
		got, dt, err := parseJSONValue(json.Number("42"))
		assert.NoError(t, err)
		assert.Equal(t, schemapb.DataType_Int64, dt)
		assert.Equal(t, int64(42), got.GetInt64Val())
	})
	t.Run("parse float from json.Number", func(t *testing.T) {
		// Fractional JSON numbers are parsed as Double.
		got, dt, err := parseJSONValue(json.Number("3.14"))
		assert.NoError(t, err)
		assert.Equal(t, schemapb.DataType_Double, dt)
		assert.Equal(t, 3.14, got.GetFloatVal())
	})
	t.Run("parse string", func(t *testing.T) {
		got, dt, err := parseJSONValue("hello")
		assert.NoError(t, err)
		assert.Equal(t, schemapb.DataType_String, dt)
		assert.Equal(t, "hello", got.GetStringVal())
	})
	t.Run("parse bool true", func(t *testing.T) {
		got, dt, err := parseJSONValue(true)
		assert.NoError(t, err)
		assert.Equal(t, schemapb.DataType_Bool, dt)
		assert.True(t, got.GetBoolVal())
	})
	t.Run("parse bool false", func(t *testing.T) {
		got, dt, err := parseJSONValue(false)
		assert.NoError(t, err)
		assert.Equal(t, schemapb.DataType_Bool, dt)
		assert.False(t, got.GetBoolVal())
	})
	t.Run("parse array of integers", func(t *testing.T) {
		// Homogeneous arrays report sameType together with the shared element type.
		got, dt, err := parseJSONValue([]interface{}{json.Number("1"), json.Number("2"), json.Number("3")})
		assert.NoError(t, err)
		assert.Equal(t, schemapb.DataType_Array, dt)
		assert.True(t, got.GetArrayVal().GetSameType())
		assert.Equal(t, schemapb.DataType_Int64, got.GetArrayVal().GetElementType())
		assert.Len(t, got.GetArrayVal().GetArray(), 3)
	})
	t.Run("parse array of mixed types", func(t *testing.T) {
		// Heterogeneous arrays must clear the sameType flag.
		got, dt, err := parseJSONValue([]interface{}{json.Number("1"), "hello", true})
		assert.NoError(t, err)
		assert.Equal(t, schemapb.DataType_Array, dt)
		assert.False(t, got.GetArrayVal().GetSameType())
		assert.Len(t, got.GetArrayVal().GetArray(), 3)
	})
	t.Run("parse empty array", func(t *testing.T) {
		got, dt, err := parseJSONValue([]interface{}{})
		assert.NoError(t, err)
		assert.Equal(t, schemapb.DataType_Array, dt)
		assert.Len(t, got.GetArrayVal().GetArray(), 0)
	})
	t.Run("invalid json.Number", func(t *testing.T) {
		// Malformed numbers surface a conversion error.
		_, _, err := parseJSONValue(json.Number("not_a_number"))
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "couldn't convert it")
	})
	t.Run("unknown type returns error", func(t *testing.T) {
		_, _, err := parseJSONValue(struct{}{})
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "unknown type")
	})
	t.Run("nested array with invalid element", func(t *testing.T) {
		// An invalid element anywhere inside an array fails the whole parse.
		_, _, err := parseJSONValue([]interface{}{struct{}{}})
		assert.Error(t, err)
	})
}
// Test_checkValidPoint tests WKT point validation
// This ensures only valid POINT geometries are accepted
func Test_checkValidPoint(t *testing.T) {
	cases := []struct {
		name    string
		wkt     string
		wantErr bool
	}{
		{"valid point", "POINT(1 2)", false},
		{"valid point with decimal", "POINT(1.5 2.5)", false},
		{"valid point with negative coordinates", "POINT(-1.5 -2.5)", false},
		{"invalid WKT syntax", "invalid", true},
		{"empty string", "", true},
		// Whitespace inside the parentheses is tolerated.
		{"point with extra spaces", "POINT( 1 2 )", false},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			err := checkValidPoint(c.wkt)
			if c.wantErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// Test_convertHanToASCII_FastPath tests the Chinese character to Unicode escape
// conversion; the function short-circuits for ASCII-only input to avoid allocation.
func Test_convertHanToASCII_FastPath(t *testing.T) {
	t.Run("ASCII only string returns unchanged (fast path)", func(t *testing.T) {
		// Pure-ASCII input exercises the no-allocation fast path.
		in := "hello world 123"
		assert.Equal(t, in, convertHanToASCII(in))
	})
	t.Run("Chinese characters are converted", func(t *testing.T) {
		// Han characters are rewritten as \uXXXX escapes.
		out := convertHanToASCII("年份")
		assert.NotEqual(t, "年份", out)
		assert.Contains(t, out, "\\u")
	})
	t.Run("mixed ASCII and Chinese", func(t *testing.T) {
		// Only non-ASCII runes are escaped; ASCII text is kept verbatim.
		out := convertHanToASCII("field年份")
		assert.Contains(t, out, "field")
		assert.Contains(t, out, "\\u")
	})
	t.Run("string with escape sequence", func(t *testing.T) {
		// Recognized escape sequences pass through untouched.
		assert.Equal(t, "\\n", convertHanToASCII("\\n"))
	})
	t.Run("string with invalid escape returns original", func(t *testing.T) {
		// An unrecognized escape triggers the early-return path.
		assert.Equal(t, "\\x", convertHanToASCII("\\x"))
	})
}
// Test_canArithmetic tests arithmetic operation type compatibility
// This ensures proper type checking for arithmetic expressions
func Test_canArithmetic(t *testing.T) {
	type arithCase struct {
		name      string
		lhs       schemapb.DataType
		lhsElem   schemapb.DataType
		rhs       schemapb.DataType
		rhsElem   schemapb.DataType
		reversed  bool
		wantError bool
	}
	cases := []arithCase{
		{name: "Int64 with Int64", lhs: schemapb.DataType_Int64, rhs: schemapb.DataType_Int64},
		{name: "Float with Float", lhs: schemapb.DataType_Float, rhs: schemapb.DataType_Float},
		{name: "Float with Int64", lhs: schemapb.DataType_Float, rhs: schemapb.DataType_Int64},
		{name: "JSON with Int64", lhs: schemapb.DataType_JSON, rhs: schemapb.DataType_Int64},
		{name: "VarChar with Int64 is invalid", lhs: schemapb.DataType_VarChar, rhs: schemapb.DataType_Int64, wantError: true},
		{name: "Bool with Int64 is invalid", lhs: schemapb.DataType_Bool, rhs: schemapb.DataType_Int64, wantError: true},
		{name: "Array of Int64 with Int64", lhs: schemapb.DataType_Array, lhsElem: schemapb.DataType_Int64, rhs: schemapb.DataType_Int64},
		{name: "reverse flag swaps operands", lhs: schemapb.DataType_Int64, rhs: schemapb.DataType_Float, reversed: true},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			err := canArithmetic(c.lhs, c.lhsElem, c.rhs, c.rhsElem, c.reversed)
			if c.wantError {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// Test_checkValidModArith tests modulo operation validation
// Modulo can only be applied to integer types
func Test_checkValidModArith(t *testing.T) {
	cases := []struct {
		name        string
		op          planpb.ArithOpType
		lhs         schemapb.DataType
		rhs         schemapb.DataType
		wantErr     bool
		errContains string
	}{
		{name: "mod with integers is valid", op: planpb.ArithOpType_Mod, lhs: schemapb.DataType_Int64, rhs: schemapb.DataType_Int64},
		{name: "mod with float left is invalid", op: planpb.ArithOpType_Mod, lhs: schemapb.DataType_Float, rhs: schemapb.DataType_Int64, wantErr: true, errContains: "modulo can only apply on integer types"},
		{name: "mod with float right is invalid", op: planpb.ArithOpType_Mod, lhs: schemapb.DataType_Int64, rhs: schemapb.DataType_Float, wantErr: true},
		// Non-mod operators are not validated by this function at all.
		{name: "add operation is always valid", op: planpb.ArithOpType_Add, lhs: schemapb.DataType_Float, rhs: schemapb.DataType_Float},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			err := checkValidModArith(c.op, c.lhs, schemapb.DataType_None, c.rhs, schemapb.DataType_None)
			if !c.wantErr {
				assert.NoError(t, err)
				return
			}
			assert.Error(t, err)
			if c.errContains != "" {
				assert.Contains(t, err.Error(), c.errContains)
			}
		})
	}
}
// Test_castRangeValue tests value casting for range operations
// This ensures proper type validation and conversion for range expressions
func Test_castRangeValue(t *testing.T) {
	t.Run("string value for string type", func(t *testing.T) {
		got, err := castRangeValue(schemapb.DataType_VarChar, NewString("test"))
		assert.NoError(t, err)
		assert.Equal(t, "test", got.GetStringVal())
	})
	t.Run("non-string value for string type fails", func(t *testing.T) {
		_, err := castRangeValue(schemapb.DataType_VarChar, NewInt(42))
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "invalid range operations")
	})
	t.Run("bool type is invalid for range", func(t *testing.T) {
		// Ranges over booleans are rejected outright.
		_, err := castRangeValue(schemapb.DataType_Bool, NewBool(true))
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "invalid range operations on boolean expr")
	})
	t.Run("integer value for integer type", func(t *testing.T) {
		got, err := castRangeValue(schemapb.DataType_Int64, NewInt(42))
		assert.NoError(t, err)
		assert.Equal(t, int64(42), got.GetInt64Val())
	})
	t.Run("non-integer value for integer type fails", func(t *testing.T) {
		_, err := castRangeValue(schemapb.DataType_Int64, NewFloat(3.14))
		assert.Error(t, err)
	})
	t.Run("float value for float type", func(t *testing.T) {
		got, err := castRangeValue(schemapb.DataType_Float, NewFloat(3.14))
		assert.NoError(t, err)
		assert.Equal(t, 3.14, got.GetFloatVal())
	})
	t.Run("integer value promoted to float for float type", func(t *testing.T) {
		// An integer literal is widened when the target column is floating point.
		got, err := castRangeValue(schemapb.DataType_Double, NewInt(42))
		assert.NoError(t, err)
		assert.Equal(t, float64(42), got.GetFloatVal())
	})
	t.Run("non-number value for float type fails", func(t *testing.T) {
		_, err := castRangeValue(schemapb.DataType_Float, NewString("test"))
		assert.Error(t, err)
	})
}
// Test_hexDigit tests the hexDigit helper function
// This is used for Unicode escape encoding
func Test_hexDigit(t *testing.T) {
	// Exercise the full nibble range: 0-9 map to '0'-'9', 10-15 to 'a'-'f'.
	for d := uint32(0); d < 16; d++ {
		var want byte
		if d < 10 {
			want = '0' + byte(d)
		} else {
			want = 'a' + byte(d-10)
		}
		assert.Equal(t, want, hexDigit(d), "hexDigit(%d) should be %c", d, want)
	}
	// Only the low 4 bits are significant: 0x1f & 0xf == 15 == 'f'.
	assert.Equal(t, byte('f'), hexDigit(0x1f))
}
// Test_formatUnicode tests Unicode escape formatting
func Test_formatUnicode(t *testing.T) {
	// Both a Han character and a plain ASCII letter render as \uXXXX.
	assert.Equal(t, "\\u5e74", formatUnicode(0x5e74)) // '年'
	assert.Equal(t, "\\u0041", formatUnicode(0x0041)) // 'A'
}
// Test_isEscapeCh tests escape character detection
func Test_isEscapeCh(t *testing.T) {
	// Raw string keeps the backslash literal: \ n t r f " '
	for _, ch := range []byte(`\ntrf"'`) {
		assert.True(t, isEscapeCh(ch), "isEscapeCh(%c) should be true", ch)
	}
	for _, ch := range []byte("ab1 x") {
		assert.False(t, isEscapeCh(ch), "isEscapeCh(%c) should be false", ch)
	}
}
// Test_isEmptyExpression_Utils tests empty expression detection
func Test_isEmptyExpression_Utils(t *testing.T) {
	// Whitespace-only strings count as empty.
	for _, s := range []string{"", " ", "\t\n"} {
		assert.True(t, isEmptyExpression(s))
	}
	// Anything with visible content does not.
	for _, s := range []string{"a > 1", " a > 1 "} {
		assert.False(t, isEmptyExpression(s))
	}
}
// Test_checkValidWKT tests WKT validation
func Test_checkValidWKT(t *testing.T) {
	cases := []struct {
		name    string
		wkt     string
		wantErr bool
	}{
		{"valid point", "POINT(1 2)", false},
		{"valid polygon", "POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))", false},
		{"invalid WKT", "invalid geometry", true},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			err := checkValidWKT(c.wkt)
			if c.wantErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
func TestParseISO8601Duration(t *testing.T) {
testCases := []struct {
name string