diff --git a/internal/util/exprutil/expr_checker.go b/internal/util/exprutil/expr_checker.go index 894f869d7d..17a471b178 100644 --- a/internal/util/exprutil/expr_checker.go +++ b/internal/util/exprutil/expr_checker.go @@ -5,9 +5,11 @@ import ( "strings" "github.com/cockroachdb/errors" + "github.com/samber/lo" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/v2/proto/planpb" + "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) type KeyType int64 @@ -37,92 +39,121 @@ func ParseExprFromPlan(plan *planpb.PlanNode) (*planpb.Expr, error) { return expr, nil } +// ParsePartitionKeysFromBinaryExpr parses BinaryExpr is prunble +// if true, returns candidate key values base on the Logical op type. func ParsePartitionKeysFromBinaryExpr(expr *planpb.BinaryExpr, keyType KeyType) ([]*planpb.GenericValue, bool) { - leftRes, leftInRange := ParseKeysFromExpr(expr.Left, keyType) - rightRes, rightInRange := ParseKeysFromExpr(expr.Right, keyType) + lCandidates, lPrunable := ParseKeysFromExpr(expr.Left, keyType) + rCandidate, rPrunable := ParseKeysFromExpr(expr.Right, keyType) if expr.Op == planpb.BinaryExpr_LogicalAnd { - // case: partition_key_field in [7, 8] && partition_key > 8 - if len(leftRes)+len(rightRes) > 0 { - leftRes = append(leftRes, rightRes...) - return leftRes, false + switch { + case lPrunable && rPrunable: + // case: partition_key in [7, 8] && partition_key in [8, 9] + // return [7, 8] intersect [8, 9] = [8] + return IntersectKeys(lCandidates, rCandidate), true + case lPrunable && !rPrunable: + return lCandidates, true + case !lPrunable && rPrunable: + return rCandidate, true + case !lPrunable && !rPrunable: + return nil, false } - - // case: other_field > 10 && partition_key_field > 8 - return nil, leftInRange || rightInRange } if expr.Op == planpb.BinaryExpr_LogicalOr { - // case: partition_key_field in [7, 8] or partition_key > 8 - if leftInRange || rightInRange { - return nil, true + if lPrunable && rPrunable { + // case: partition_key in [7, 8] || partition_key in [8, 9] + // return [7, 8] union [8, 9] = [7, 8, 9] + return append(lCandidates, rCandidate...), true } - - // case: partition_key_field in [7, 8] or other_field > 10 - leftRes = append(leftRes, rightRes...) - return leftRes, false + return nil, false } return nil, false } +// ParsePartitionKeysFromUnaryExpr parses UnaryExpr is prunble. +// currently, only "Not" is supported, which means unary expression is always not prunable. func ParsePartitionKeysFromUnaryExpr(expr *planpb.UnaryExpr, keyType KeyType) ([]*planpb.GenericValue, bool) { - res, partitionInRange := ParseKeysFromExpr(expr.GetChild(), keyType) - if expr.Op == planpb.UnaryExpr_Not { - // case: partition_key_field not in [7, 8] - if len(res) != 0 { - return nil, true - } - - // case: other_field not in [10] - return nil, partitionInRange - } - - // UnaryOp only includes "Not" for now - return res, partitionInRange + return nil, false } +// ParsePartitionKeysFromTermExpr parses TermExpr is prunble. +// it checks if the term expression is a partition key or clustering key. func ParsePartitionKeysFromTermExpr(expr *planpb.TermExpr, keyType KeyType) ([]*planpb.GenericValue, bool) { if keyType == PartitionKey && expr.GetColumnInfo().GetIsPartitionKey() { - return expr.GetValues(), false + return expr.GetValues(), true } else if keyType == ClusteringKey && expr.GetColumnInfo().GetIsClusteringKey() { - return expr.GetValues(), false + return expr.GetValues(), true } return nil, false } -func ParsePartitionKeysFromUnaryRangeExpr(expr *planpb.UnaryRangeExpr, keyType KeyType) ([]*planpb.GenericValue, bool) { +// ParsePartitionKeysFromUnaryRangeExpr parses UnaryRangeExpr is prunble. +func ParsePartitionKeysFromUnaryRangeExpr(expr *planpb.UnaryRangeExpr, keyType KeyType) (candidate []*planpb.GenericValue, prunable bool) { if expr.GetOp() == planpb.OpType_Equal { if expr.GetColumnInfo().GetIsPartitionKey() && keyType == PartitionKey || expr.GetColumnInfo().GetIsClusteringKey() && keyType == ClusteringKey { - return []*planpb.GenericValue{expr.Value}, false + return []*planpb.GenericValue{expr.Value}, true } } - return nil, true + return nil, false } -func ParseKeysFromExpr(expr *planpb.Expr, keyType KeyType) ([]*planpb.GenericValue, bool) { - var res []*planpb.GenericValue - keyInRange := false +// ParseKeysFromExpr parses keys from the given expression based on the key type. +// If the expression can limit the search scope to specified partitions, return the corresponding key values and a flag indicating whether pruning is possible. +// otherwise, return nil and false indicating that pruning is not possible base on this expression. +func ParseKeysFromExpr(expr *planpb.Expr, keyType KeyType) (candidates []*planpb.GenericValue, prunable bool) { switch expr := expr.GetExpr().(type) { case *planpb.Expr_BinaryExpr: - res, keyInRange = ParsePartitionKeysFromBinaryExpr(expr.BinaryExpr, keyType) + candidates, prunable = ParsePartitionKeysFromBinaryExpr(expr.BinaryExpr, keyType) case *planpb.Expr_UnaryExpr: - res, keyInRange = ParsePartitionKeysFromUnaryExpr(expr.UnaryExpr, keyType) + candidates, prunable = ParsePartitionKeysFromUnaryExpr(expr.UnaryExpr, keyType) case *planpb.Expr_TermExpr: - res, keyInRange = ParsePartitionKeysFromTermExpr(expr.TermExpr, keyType) + candidates, prunable = ParsePartitionKeysFromTermExpr(expr.TermExpr, keyType) case *planpb.Expr_UnaryRangeExpr: - res, keyInRange = ParsePartitionKeysFromUnaryRangeExpr(expr.UnaryRangeExpr, keyType) + candidates, prunable = ParsePartitionKeysFromUnaryRangeExpr(expr.UnaryRangeExpr, keyType) } - return res, keyInRange + return candidates, prunable +} + +func IntersectKeys(l []*planpb.GenericValue, r []*planpb.GenericValue) []*planpb.GenericValue { + if len(l) == 0 || len(r) == 0 { + return nil + } + // all elements shall be in same type + switch l[0].Val.(type) { + case *planpb.GenericValue_Int64Val: + lSet := typeutil.NewSet(lo.Map(l, func(e *planpb.GenericValue, _ int) int64 { return e.GetInt64Val() })...) + rSet := typeutil.NewSet(lo.Map(r, func(e *planpb.GenericValue, _ int) int64 { return e.GetInt64Val() })...) + return lo.Map(lSet.Intersection(rSet).Collect(), func(e int64, _ int) *planpb.GenericValue { + return &planpb.GenericValue{ + Val: &planpb.GenericValue_Int64Val{ + Int64Val: e, + }, + } + }) + case *planpb.GenericValue_StringVal: + lSet := typeutil.NewSet(lo.Map(l, func(e *planpb.GenericValue, _ int) string { return e.GetStringVal() })...) + rSet := typeutil.NewSet(lo.Map(r, func(e *planpb.GenericValue, _ int) string { return e.GetStringVal() })...) + return lo.Map(lSet.Intersection(rSet).Collect(), func(e string, _ int) *planpb.GenericValue { + return &planpb.GenericValue{ + Val: &planpb.GenericValue_StringVal{ + StringVal: e, + }, + } + }) + } + return nil } func ParseKeys(expr *planpb.Expr, kType KeyType) []*planpb.GenericValue { - res, keyInRange := ParseKeysFromExpr(expr, kType) - if keyInRange { + res, prunable := ParseKeysFromExpr(expr, kType) + if !prunable { res = nil } + // TODO return empty result if prunable and candidates lens is 0 return res } diff --git a/internal/util/exprutil/expr_checker_test.go b/internal/util/exprutil/expr_checker_test.go index 2a9a3e5f82..0eaf85f7ae 100644 --- a/internal/util/exprutil/expr_checker_test.go +++ b/internal/util/exprutil/expr_checker_test.go @@ -108,9 +108,9 @@ func TestParsePartitionKeys(t *testing.T) { { name: "binary_expr_or with term and not 2", expr: "partition_key_field in [7, 8] or int64_field not in [10, 20]", - expected: 2, - validPartitionKeys: []int64{7, 8}, - invalidPartitionKeys: []int64{10, 20}, + expected: 0, + validPartitionKeys: []int64{}, + invalidPartitionKeys: []int64{}, }, }