diff --git a/internal/util/exprutil/expr_checker.go b/internal/util/exprutil/expr_checker.go index 17a471b178..99c003382e 100644 --- a/internal/util/exprutil/expr_checker.go +++ b/internal/util/exprutil/expr_checker.go @@ -2,7 +2,6 @@ package exprutil import ( "math" - "strings" "github.com/cockroachdb/errors" "github.com/samber/lo" @@ -269,183 +268,6 @@ func StrRangeOverlap(range1 *StrRange, range2 *StrRange) bool { return leftBound <= rightBound } -/* -principles for range parsing -1. no handling unary expr like 'NOT' -2. no handling 'or' expr, no matter on clusteringKey or not, just terminate all possible prune -3. for any unlogical 'and' expr, we check and terminate upper away -4. no handling Term and Range at the same time -*/ - -func ParseRanges(expr *planpb.Expr, kType KeyType) ([]*PlanRange, bool) { - var res []*PlanRange - matchALL := true - switch expr := expr.GetExpr().(type) { - case *planpb.Expr_BinaryExpr: - res, matchALL = ParseRangesFromBinaryExpr(expr.BinaryExpr, kType) - case *planpb.Expr_UnaryRangeExpr: - res, matchALL = ParseRangesFromUnaryRangeExpr(expr.UnaryRangeExpr, kType) - case *planpb.Expr_TermExpr: - res, matchALL = ParseRangesFromTermExpr(expr.TermExpr, kType) - case *planpb.Expr_UnaryExpr: - res, matchALL = nil, true - // we don't handle NOT operation, just consider as unable_to_parse_range - } - return res, matchALL -} - -func ParseRangesFromBinaryExpr(expr *planpb.BinaryExpr, kType KeyType) ([]*PlanRange, bool) { - if expr.Op == planpb.BinaryExpr_LogicalOr { - return nil, true - } - _, leftIsTerm := expr.GetLeft().GetExpr().(*planpb.Expr_TermExpr) - _, rightIsTerm := expr.GetRight().GetExpr().(*planpb.Expr_TermExpr) - if leftIsTerm || rightIsTerm { - // either of lower or upper is term query like x IN [1,2,3] - // we will terminate the prune process - return nil, true - } - leftRanges, leftALL := ParseRanges(expr.Left, kType) - rightRanges, rightALL := ParseRanges(expr.Right, kType) - if leftALL && rightALL { - return nil, true - } else if leftALL && !rightALL { - return rightRanges, rightALL - } else if rightALL && !leftALL { - return leftRanges, leftALL - } - // only unary ranges or further binary ranges are lower - // calculate the intersection and return the resulting ranges - // it's expected that only single range can be returned from lower and upper child - if len(leftRanges) != 1 || len(rightRanges) != 1 { - return nil, true - } - intersected := Intersect(leftRanges[0], rightRanges[0]) - matchALL := intersected == nil - return []*PlanRange{intersected}, matchALL -} - -func ParseRangesFromUnaryRangeExpr(expr *planpb.UnaryRangeExpr, kType KeyType) ([]*PlanRange, bool) { - if expr.GetColumnInfo().GetIsPartitionKey() && kType == PartitionKey || - expr.GetColumnInfo().GetIsClusteringKey() && kType == ClusteringKey { - switch expr.GetOp() { - case planpb.OpType_Equal: - { - return []*PlanRange{ - { - lower: expr.Value, - upper: expr.Value, - includeLower: true, - includeUpper: true, - }, - }, false - } - case planpb.OpType_GreaterThan: - { - return []*PlanRange{ - { - lower: expr.Value, - upper: nil, - includeLower: false, - includeUpper: false, - }, - }, false - } - case planpb.OpType_GreaterEqual: - { - return []*PlanRange{ - { - lower: expr.Value, - upper: nil, - includeLower: true, - includeUpper: false, - }, - }, false - } - case planpb.OpType_LessThan: - { - return []*PlanRange{ - { - lower: nil, - upper: expr.Value, - includeLower: false, - includeUpper: false, - }, - }, false - } - case planpb.OpType_LessEqual: - { - return []*PlanRange{ - { - lower: nil, - upper: expr.Value, - includeLower: false, - includeUpper: true, - }, - }, false - } - } - } - return nil, true -} - -func ParseRangesFromTermExpr(expr *planpb.TermExpr, kType KeyType) ([]*PlanRange, bool) { - if expr.GetColumnInfo().GetIsPartitionKey() && kType == PartitionKey || - expr.GetColumnInfo().GetIsClusteringKey() && kType == ClusteringKey { - res := make([]*PlanRange, 0) - for _, value := range expr.GetValues() { - res = append(res, &PlanRange{ - lower: value, - upper: value, - includeLower: true, - includeUpper: true, - }) - } - return res, false - } - return nil, true -} - -var minusInfiniteInt = &planpb.GenericValue{ - Val: &planpb.GenericValue_Int64Val{ - Int64Val: math.MinInt64, - }, -} - -var positiveInfiniteInt = &planpb.GenericValue{ - Val: &planpb.GenericValue_Int64Val{ - Int64Val: math.MaxInt64, - }, -} - -var minStrVal = &planpb.GenericValue{ - Val: &planpb.GenericValue_StringVal{ - StringVal: "", - }, -} - -var maxStrVal = &planpb.GenericValue{} - -func complementPlanRange(pr *PlanRange, dataType schemapb.DataType) *PlanRange { - if dataType == schemapb.DataType_Int64 { - if pr.lower == nil { - pr.lower = minusInfiniteInt - } - if pr.upper == nil { - pr.upper = positiveInfiniteInt - } - } else { - if pr.lower == nil { - pr.lower = minStrVal - } - if pr.upper == nil { - pr.upper = maxStrVal - } - } - - return pr -} - func GetCommonDataType(a *PlanRange, b *PlanRange) schemapb.DataType { var bound *planpb.GenericValue if a.lower != nil { @@ -476,71 +298,6 @@ func GetCommonDataType(a *PlanRange, b *PlanRange) schemapb.DataType { return schemapb.DataType_None } -func Intersect(a *PlanRange, b *PlanRange) *PlanRange { - dataType := GetCommonDataType(a, b) - complementPlanRange(a, dataType) - complementPlanRange(b, dataType) - - // Check if 'a' and 'b' non-overlapping at all - rightBound := minGenericValue(a.upper, b.upper) - leftBound := maxGenericValue(a.lower, b.lower) - if compareGenericValue(leftBound, rightBound) > 0 { - return nil - } - - // Check if 'a' range ends exactly where 'b' range starts - if !a.includeUpper && !b.includeLower && (compareGenericValue(a.upper, b.lower) == 0) { - return nil - } - // Check if 'b' range ends exactly where 'a' range starts - if !b.includeUpper && !a.includeLower && (compareGenericValue(b.upper, a.lower) == 0) { - return nil - } - - return &PlanRange{ - lower: leftBound, - upper: rightBound, - includeLower: a.includeLower || b.includeLower, - includeUpper: a.includeUpper || b.includeUpper, - } -} - -func compareGenericValue(left *planpb.GenericValue, right *planpb.GenericValue) int64 { - if right == nil || left == nil { - return -1 - } - switch left.Val.(type) { - case *planpb.GenericValue_Int64Val: - if left.GetInt64Val() == right.GetInt64Val() { - return 0 - } else if left.GetInt64Val() < right.GetInt64Val() { - return -1 - } else { - return 1 - } - case *planpb.GenericValue_StringVal: - if right.Val == nil { - return -1 - } - return int64(strings.Compare(left.GetStringVal(), right.GetStringVal())) - } - return 0 -} - -func minGenericValue(left *planpb.GenericValue, right *planpb.GenericValue) *planpb.GenericValue { - if compareGenericValue(left, right) < 0 { - return left - } - return right -} - -func maxGenericValue(left *planpb.GenericValue, right *planpb.GenericValue) *planpb.GenericValue { - if compareGenericValue(left, right) >= 0 { - return left - } - return right -} - func ValidatePartitionKeyIsolation(expr *planpb.Expr) error { foundPartitionKey, err := validatePartitionKeyIsolationFromExpr(expr) if err != nil { diff --git a/internal/util/exprutil/expr_checker_test.go b/internal/util/exprutil/expr_checker_test.go index 12f3c72878..e64259990a 100644 --- a/internal/util/exprutil/expr_checker_test.go +++ b/internal/util/exprutil/expr_checker_test.go @@ -148,140 +148,6 @@ func TestParsePartitionKeys(t *testing.T) { } } -func TestParseIntRanges(t *testing.T) { - prefix := "TestParseRanges" - clusterKeyField := "cluster_key_field" - collectionName := prefix + funcutil.GenRandomStr() - - fieldName2Type := make(map[string]schemapb.DataType) - fieldName2Type["int64_field"] = schemapb.DataType_Int64 - fieldName2Type["varChar_field"] = schemapb.DataType_VarChar - fieldName2Type["fvec_field"] = schemapb.DataType_FloatVector - schema := testutil.ConstructCollectionSchemaByDataType(collectionName, fieldName2Type, - "int64_field", false, 8) - clusterKeyFieldSchema := &schemapb.FieldSchema{ - Name: clusterKeyField, - DataType: schemapb.DataType_Int64, - IsClusteringKey: true, - } - schema.Fields = append(schema.Fields, clusterKeyFieldSchema) - - fieldID := common.StartOfUserFieldID - for _, field := range schema.Fields { - field.FieldID = int64(fieldID) - fieldID++ - } - schemaHelper, err := typeutil.CreateSchemaHelper(schema) - require.NoError(t, err) - // test query plan - { - expr := "cluster_key_field > 50" - queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, expr, nil) - assert.NoError(t, err) - planExpr, err := ParseExprFromPlan(queryPlan) - assert.NoError(t, err) - parsedRanges, matchALL := ParseRanges(planExpr, ClusteringKey) - assert.False(t, matchALL) - assert.Equal(t, 1, len(parsedRanges)) - range0 := parsedRanges[0] - assert.Equal(t, range0.lower.Val.(*planpb.GenericValue_Int64Val).Int64Val, int64(50)) - assert.Nil(t, range0.upper) - assert.Equal(t, range0.includeLower, false) - assert.Equal(t, range0.includeUpper, false) - } - - // test binary query plan - { - expr := "cluster_key_field > 50 and cluster_key_field <= 100" - queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, expr, nil) - assert.NoError(t, err) - planExpr, err := ParseExprFromPlan(queryPlan) - assert.NoError(t, err) - parsedRanges, matchALL := ParseRanges(planExpr, ClusteringKey) - assert.False(t, matchALL) - assert.Equal(t, 1, len(parsedRanges)) - range0 := parsedRanges[0] - assert.Equal(t, range0.lower.Val.(*planpb.GenericValue_Int64Val).Int64Val, int64(50)) - assert.Equal(t, false, range0.includeLower) - assert.Equal(t, true, range0.includeUpper) - } - - // test binary query plan - { - expr := "cluster_key_field >= 50 and cluster_key_field < 100" - queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, expr, nil) - assert.NoError(t, err) - planExpr, err := ParseExprFromPlan(queryPlan) - assert.NoError(t, err) - parsedRanges, matchALL := ParseRanges(planExpr, ClusteringKey) - assert.False(t, matchALL) - assert.Equal(t, 1, len(parsedRanges)) - range0 := parsedRanges[0] - assert.Equal(t, range0.lower.Val.(*planpb.GenericValue_Int64Val).Int64Val, int64(50)) - assert.Equal(t, true, range0.includeLower) - assert.Equal(t, false, range0.includeUpper) - } - - // test binary query plan - { - expr := "cluster_key_field in [100]" - queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, expr, nil) - assert.NoError(t, err) - planExpr, err := ParseExprFromPlan(queryPlan) - assert.NoError(t, err) - parsedRanges, matchALL := ParseRanges(planExpr, ClusteringKey) - assert.False(t, matchALL) - assert.Equal(t, 1, len(parsedRanges)) - range0 := parsedRanges[0] - assert.Equal(t, range0.lower.Val.(*planpb.GenericValue_Int64Val).Int64Val, int64(100)) - assert.Equal(t, true, range0.includeLower) - assert.Equal(t, true, range0.includeUpper) - } -} - -func TestParseStrRanges(t *testing.T) { - prefix := "TestParseRanges" - clusterKeyField := "cluster_key_field" - collectionName := prefix + funcutil.GenRandomStr() - - fieldName2Type := make(map[string]schemapb.DataType) - fieldName2Type["int64_field"] = schemapb.DataType_Int64 - fieldName2Type["varChar_field"] = schemapb.DataType_VarChar - fieldName2Type["fvec_field"] = schemapb.DataType_FloatVector - schema := testutil.ConstructCollectionSchemaByDataType(collectionName, fieldName2Type, - "int64_field", false, 8) - clusterKeyFieldSchema := &schemapb.FieldSchema{ - Name: clusterKeyField, - DataType: schemapb.DataType_VarChar, - IsClusteringKey: true, - } - schema.Fields = append(schema.Fields, clusterKeyFieldSchema) - - fieldID := common.StartOfUserFieldID - for _, field := range schema.Fields { - field.FieldID = int64(fieldID) - fieldID++ - } - schemaHelper, err := typeutil.CreateSchemaHelper(schema) - require.NoError(t, err) - // test query plan - { - expr := "cluster_key_field >= \"aaa\"" - queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, expr, nil) - assert.NoError(t, err) - planExpr, err := ParseExprFromPlan(queryPlan) - assert.NoError(t, err) - parsedRanges, matchALL := ParseRanges(planExpr, ClusteringKey) - assert.False(t, matchALL) - assert.Equal(t, 1, len(parsedRanges)) - range0 := parsedRanges[0] - assert.Equal(t, range0.lower.Val.(*planpb.GenericValue_StringVal).StringVal, "aaa") - assert.Nil(t, range0.upper) - assert.Equal(t, range0.includeLower, true) - assert.Equal(t, range0.includeUpper, false) - } -} - func TestValidatePartitionKeyIsolation(t *testing.T) { prefix := "TestValidatePartitionKeyIsolation" collectionName := prefix + funcutil.GenRandomStr()