diff --git a/internal/proxy/plan_parser.go b/internal/proxy/plan_parser.go index c2c371a5ce..8877a04767 100644 --- a/internal/proxy/plan_parser.go +++ b/internal/proxy/plan_parser.go @@ -14,6 +14,7 @@ package proxy import ( "fmt" "math" + "strings" ant_ast "github.com/antonmedv/expr/ast" ant_parser "github.com/antonmedv/expr/parser" @@ -256,7 +257,33 @@ func getLogicalOpType(opStr string) planpb.BinaryExpr_BinaryOp { } } +func parseBoolNode(nodeRaw *ant_ast.Node) *ant_ast.BoolNode { + switch node := (*nodeRaw).(type) { + case *ant_ast.IdentifierNode: + val := strings.ToLower(node.Value) + if val == "true" { + return &ant_ast.BoolNode{ + Value: true, + } + } else if val == "false" { + return &ant_ast.BoolNode{ + Value: false, + } + } else { + return nil + } + default: + return nil + } +} + func (context *ParserContext) createCmpExpr(left, right ant_ast.Node, operator string) (*planpb.Expr, error) { + if boolNode := parseBoolNode(&left); boolNode != nil { + left = boolNode + } + if boolNode := parseBoolNode(&right); boolNode != nil { + right = boolNode + } idNodeLeft, leftIDNode := left.(*ant_ast.IdentifierNode) idNodeRight, rightIDNode := right.(*ant_ast.IdentifierNode) @@ -540,11 +567,20 @@ func (context *ParserContext) handleLeafValue(nodeRaw *ant_ast.Node, dataType sc Int64Val: int64(node.Value), }, } + } else if dataType == schemapb.DataType_Bool { + gv = &planpb.GenericValue{ + Val: &planpb.GenericValue_BoolVal{}, + } + if node.Value == 1 { + gv.Val.(*planpb.GenericValue_BoolVal).BoolVal = true + } else { + gv.Val.(*planpb.GenericValue_BoolVal).BoolVal = false + } } else { return nil, fmt.Errorf("type mismatch") } case *ant_ast.BoolNode: - if typeutil.IsFloatingType(dataType) { + if typeutil.IsBoolType(dataType) { gv = &planpb.GenericValue{ Val: &planpb.GenericValue_BoolVal{ BoolVal: node.Value, diff --git a/internal/proxy/plan_parser_test.go b/internal/proxy/plan_parser_test.go index aca3f9a9a0..ed2bc26f2a 100644 --- a/internal/proxy/plan_parser_test.go +++ b/internal/proxy/plan_parser_test.go @@ -15,6 +15,7 @@ import ( "fmt" "testing" + ant_ast "github.com/antonmedv/expr/ast" ant_parser "github.com/antonmedv/expr/parser" "github.com/golang/protobuf/proto" @@ -148,6 +149,11 @@ func TestExprMultiRange_Str(t *testing.T) { "0.1 ** 2 < FloatN < 2 ** 0.1", "0.1 ** 1.1 < FloatN < 3.1 / 4", "4.1 / 3 < FloatN < 0.0 / 5.0", + "BoolN1 == True", + "True == BoolN1", + "BoolN1 == False", + "BoolN1 == 1", + "BoolN1 == 0", } fields := []*schemapb.FieldSchema{ @@ -156,6 +162,7 @@ func TestExprMultiRange_Str(t *testing.T) { {FieldID: 102, Name: "age2", DataType: schemapb.DataType_Int64}, {FieldID: 103, Name: "FloatN", DataType: schemapb.DataType_Float}, {FieldID: 104, Name: "FloatN2", DataType: schemapb.DataType_Float}, + {FieldID: 105, Name: "BoolN1", DataType: schemapb.DataType_Bool}, } schema := &schemapb.CollectionSchema{ @@ -214,3 +221,28 @@ func TestExprFieldCompare_Str(t *testing.T) { println(dbgStr) } } + +func Test_ParseBoolNode(t *testing.T) { + var nodeRaw1, nodeRaw2, nodeRaw3, nodeRaw4 ant_ast.Node + nodeRaw1 = &ant_ast.IdentifierNode{ + Value: "True", + } + boolNode1 := parseBoolNode(&nodeRaw1) + assert.Equal(t, boolNode1.Value, true) + + nodeRaw2 = &ant_ast.IdentifierNode{ + Value: "False", + } + boolNode2 := parseBoolNode(&nodeRaw2) + assert.Equal(t, boolNode2.Value, false) + + nodeRaw3 = &ant_ast.IdentifierNode{ + Value: "abcd", + } + assert.Nil(t, parseBoolNode(&nodeRaw3)) + + nodeRaw4 = &ant_ast.BoolNode{ + Value: true, + } + assert.Nil(t, parseBoolNode(&nodeRaw4)) +} diff --git a/internal/util/typeutil/schema.go b/internal/util/typeutil/schema.go index 5c39d54c85..7830298f9c 100644 --- a/internal/util/typeutil/schema.go +++ b/internal/util/typeutil/schema.go @@ -161,3 +161,12 @@ func IsFloatingType(dataType schemapb.DataType) bool { return false } } + +func IsBoolType(dataType schemapb.DataType) bool { + switch dataType { + case schemapb.DataType_Bool: + return true + default: + return false + } +}