milvus/internal/parser/planparserv2/fill_expression_value.go
Spade A f6f716bcfd
feat: impl StructArray -- support searching individual embeddings in an embedding list with element-level filter expressions (#45830)
issue: https://github.com/milvus-io/milvus/issues/42148

For a vector field inside a STRUCT, since a STRUCT can only appear as
the element type of an ARRAY field, the vector field in STRUCT is
effectively an array of vectors, i.e. an embedding list.
Milvus already supports searching embedding lists with metrics whose
names start with the prefix MAX_SIM_.

This PR allows Milvus to search embeddings inside an embedding list
using the same metrics as normal embedding fields. Each embedding in the
list is treated as an independent vector and participates in ANN search.

Further, since STRUCT may contain scalar fields that are highly related
to the embedding field, this PR introduces an element-level filter
expression to refine search results.
The grammar of the element-level filter is:

element_filter(structFieldName, $[subFieldName] == 3)

where $[subFieldName] refers to the value of subFieldName in each
element of the STRUCT array structFieldName.

It can be combined with existing filter expressions, for example:

"varcharField == 'aaa' && element_filter(struct_field, $[struct_int] == 3)"

A full example:
```
struct_schema = milvus_client.create_struct_field_schema()
struct_schema.add_field("struct_str", DataType.VARCHAR, max_length=65535)
struct_schema.add_field("struct_int", DataType.INT32)
struct_schema.add_field("struct_float_vec", DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM)

schema.add_field(
    "struct_field",
    datatype=DataType.ARRAY,
    element_type=DataType.STRUCT,
    struct_schema=struct_schema,
    max_capacity=1000,
)
...

filter = "varcharField == 'aaa' && element_filter(struct_field, $[struct_int] == 3 && $[struct_str] == 'abc')"
res = milvus_client.search(
    COLLECTION_NAME,
    data=query_embeddings,
    limit=10,
    anns_field="struct_field[struct_float_vec]",
    filter=filter,
    output_fields=["struct_field[struct_int]", "varcharField"],
)

```
TODO:
1. When an `element_filter` expression is used, a regular filter
expression must also be present. Remove this restriction.
2. Implement `element_filter` expressions in the `query`.

---------

Signed-off-by: SpadeA <tangchenjie1210@gmail.com>
2025-12-15 12:01:15 +08:00

244 lines
8.0 KiB
Go

package planparserv2
import (
"fmt"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// FillExpressionValue walks a parsed filter expression and substitutes the
// values bound to its template variables from templateValues, modifying the
// expression tree in place.
//
// Only nodes flagged with IsTemplate are visited; an unflagged node (or a nil
// expression) is returned unchanged with a nil error. Composite nodes (unary,
// binary, binary-arith, random-sample and element-filter expressions) recurse
// into their children. Leaf nodes are delegated to the matching Fill*
// function, which casts the template value to the column's type.
//
// It returns the first error produced while filling, or an error when a
// template-flagged node has an expression type that carries no placeholders.
func FillExpressionValue(expr *planpb.Expr, templateValues map[string]*planpb.GenericValue) error {
	if !expr.GetIsTemplate() {
		// Only subtrees flagged by the parser contain placeholders.
		return nil
	}
	switch e := expr.GetExpr().(type) {
	case *planpb.Expr_TermExpr:
		return FillTermExpressionValue(e.TermExpr, templateValues)
	case *planpb.Expr_UnaryExpr:
		return FillExpressionValue(e.UnaryExpr.GetChild(), templateValues)
	case *planpb.Expr_BinaryExpr:
		if err := FillExpressionValue(e.BinaryExpr.GetLeft(), templateValues); err != nil {
			return err
		}
		return FillExpressionValue(e.BinaryExpr.GetRight(), templateValues)
	case *planpb.Expr_UnaryRangeExpr:
		return FillUnaryRangeExpressionValue(e.UnaryRangeExpr, templateValues)
	case *planpb.Expr_BinaryRangeExpr:
		return FillBinaryRangeExpressionValue(e.BinaryRangeExpr, templateValues)
	case *planpb.Expr_BinaryArithOpEvalRangeExpr:
		return FillBinaryArithOpEvalRangeExpressionValue(e.BinaryArithOpEvalRangeExpr, templateValues)
	case *planpb.Expr_BinaryArithExpr:
		if err := FillExpressionValue(e.BinaryArithExpr.GetLeft(), templateValues); err != nil {
			return err
		}
		return FillExpressionValue(e.BinaryArithExpr.GetRight(), templateValues)
	case *planpb.Expr_JsonContainsExpr:
		return FillJSONContainsExpressionValue(e.JsonContainsExpr, templateValues)
	case *planpb.Expr_RandomSampleExpr:
		// Use the type-switch binding rather than re-asserting expr.GetExpr().
		return FillExpressionValue(e.RandomSampleExpr.GetPredicate(), templateValues)
	case *planpb.Expr_ElementFilterExpr:
		if err := FillExpressionValue(e.ElementFilterExpr.GetElementExpr(), templateValues); err != nil {
			return err
		}
		// The row-level predicate is optional for element filters.
		if e.ElementFilterExpr.GetPredicate() != nil {
			return FillExpressionValue(e.ElementFilterExpr.GetPredicate(), templateValues)
		}
		return nil
	default:
		return fmt.Errorf("this expression no need to fill placeholder with expr type: %T", e)
	}
}
// FillTermExpressionValue fills the value list of a templated term (`in`)
// expression. The template value must be an array. Each of its items is cast
// to the column's data type, or to the element type for an array column
// accessed through a nested path. The casted list replaces expr.Values.
//
// A term expression that already carries values and names no template
// variable has nothing to fill and is left untouched. This matches how
// FillJSONContainsExpressionValue treats literal elements.
//
// It returns an error when the named template variable is missing, when its
// value is not an array, or when an item cannot be cast. In the cast-failure
// case expr.Values is left unchanged.
func FillTermExpressionValue(expr *planpb.TermExpr, templateValues map[string]*planpb.GenericValue) error {
	name := expr.GetTemplateVariableName()
	if expr.GetValues() != nil && name == "" {
		// Values were supplied literally; there is no placeholder to fill.
		return nil
	}
	value, ok := templateValues[name]
	if !ok {
		// Report a missing variable as such. Previously this case fell
		// through to the misleading "is not array" error whenever
		// expr.Values was already set.
		return fmt.Errorf("the value of expression template variable name {%s} is not found", name)
	}
	if value == nil || value.GetArrayVal() == nil {
		return fmt.Errorf("the value of term expression template variable {%s} is not array", name)
	}
	dataType := expr.GetColumnInfo().GetDataType()
	if typeutil.IsArrayType(dataType) && len(expr.GetColumnInfo().GetNestedPath()) != 0 {
		// `arr[i] in [...]` compares against the array's element type.
		dataType = expr.GetColumnInfo().GetElementType()
	}
	array := value.GetArrayVal().GetArray()
	values := make([]*planpb.GenericValue, len(array))
	for i, e := range array {
		castedValue, err := castValue(dataType, e)
		if err != nil {
			return err
		}
		values[i] = castedValue
	}
	expr.Values = values
	return nil
}
// FillUnaryRangeExpressionValue replaces the comparison value of a templated
// unary range expression with the value bound to its template variable.
//
// The bound value is cast to the column's data type. For an array column
// accessed through a nested path, it is cast to the array's element type
// instead. It returns an error if the variable is absent from templateValues
// or the cast fails; on error expr.Value is not modified.
func FillUnaryRangeExpressionValue(expr *planpb.UnaryRangeExpr, templateValues map[string]*planpb.GenericValue) error {
	name := expr.GetTemplateVariableName()
	raw, found := templateValues[name]
	if !found {
		return fmt.Errorf("the value of expression template variable name {%s} is not found", name)
	}

	column := expr.GetColumnInfo()
	targetType := column.GetDataType()
	if typeutil.IsArrayType(targetType) && len(column.GetNestedPath()) != 0 {
		targetType = column.GetElementType()
	}

	filled, err := castValue(targetType, raw)
	if err != nil {
		return err
	}
	expr.Value = filled
	return nil
}
// FillBinaryRangeExpressionValue resolves the lower and upper bounds of a
// templated binary range expression, then rejects ranges that cannot match.
//
// A bound is taken from templateValues when the expression holds no literal
// for it, or when it names a template variable for it. That template value is
// cast to the column's data type (or to the element type for an array column
// accessed through a nested path) and stored back on the expression. A literal
// bound with no template name is kept as-is. The lower bound is resolved and
// stored before the upper bound is looked up.
//
// The ordering check runs on the resolved values as they were before casting.
// It fails when lower > upper, or when lower >= upper if either bound is
// exclusive.
func FillBinaryRangeExpressionValue(expr *planpb.BinaryRangeExpr, templateValues map[string]*planpb.GenericValue) error {
	column := expr.GetColumnInfo()
	targetType := column.GetDataType()
	if typeutil.IsArrayType(targetType) && len(column.GetNestedPath()) != 0 {
		targetType = column.GetElementType()
	}

	// resolve returns the bound used for the ordering check. For a templated
	// bound it also casts the template value and hands it to store.
	resolve := func(side string, current *planpb.GenericValue, name string, store func(*planpb.GenericValue)) (*planpb.GenericValue, error) {
		if current != nil && name == "" {
			return current, nil
		}
		raw, found := templateValues[name]
		if !found {
			return nil, fmt.Errorf("the %s value of expression template variable name {%s} is not found", side, name)
		}
		casted, err := castValue(targetType, raw)
		if err != nil {
			return nil, err
		}
		store(casted)
		return raw, nil
	}

	lower, err := resolve("lower", expr.GetLowerValue(), expr.GetLowerTemplateVariableName(),
		func(v *planpb.GenericValue) { expr.LowerValue = v })
	if err != nil {
		return err
	}
	upper, err := resolve("upper", expr.GetUpperValue(), expr.GetUpperTemplateVariableName(),
		func(v *planpb.GenericValue) { expr.UpperValue = v })
	if err != nil {
		return err
	}

	// With an exclusive bound on either side, lower == upper is also empty.
	var inverted bool
	if !expr.GetLowerInclusive() || !expr.GetUpperInclusive() {
		inverted = getGenericValue(GreaterEqual(lower, upper)).GetBoolVal()
	} else {
		inverted = getGenericValue(Greater(lower, upper)).GetBoolVal()
	}
	if inverted {
		return errors.New("invalid range: lowerbound is greater than upperbound")
	}
	return nil
}
// FillBinaryArithOpEvalRangeExpressionValue fills the templated right operand
// and comparison value of an arithmetic comparison expression.
//
// For ArithOpType_ArrayLength, the right operand is left untouched and the
// comparison value is cast to Int64. For every other op, the right operand is
// handled in these steps:
//   - It is taken from the expression, or from templateValues when the
//     expression has none or names a template variable for it.
//   - It is validated with checkValidModArith.
//   - It is rejected if it is an array value.
//   - It is cast to the type chosen by getTargetType for the column/operand
//     pair, and stored.
//
// That same target type is then used to cast the comparison value. The
// comparison value is resolved the same way: the literal, or the template
// value when missing or templated.
func FillBinaryArithOpEvalRangeExpressionValue(expr *planpb.BinaryArithOpEvalRangeExpr, templateValues map[string]*planpb.GenericValue) error {
	targetType := schemapb.DataType_Int64
	if expr.ArithOp != planpb.ArithOpType_ArrayLength {
		operand := expr.GetRightOperand()
		if name := expr.GetOperandTemplateVariableName(); operand == nil || name != "" {
			var found bool
			if operand, found = templateValues[name]; !found {
				return fmt.Errorf("the right operand value of expression template variable name {%s} is not found", name)
			}
		}

		rightType := toValueExpr(operand).dataType
		column := expr.GetColumnInfo()
		leftType := column.GetDataType()
		if typeutil.IsArrayType(leftType) {
			leftType = column.GetElementType()
		}
		if err := checkValidModArith(expr.GetArithOp(), column.GetDataType(), column.GetElementType(),
			rightType, schemapb.DataType_None); err != nil {
			return err
		}
		if operand.GetArrayVal() != nil {
			return errors.New("can not comparisons array directly")
		}

		resolved, err := getTargetType(leftType, rightType)
		if err != nil {
			return err
		}
		targetType = resolved

		castedOperand, err := castValue(targetType, operand)
		if err != nil {
			return err
		}
		expr.RightOperand = castedOperand
	}

	value := expr.GetValue()
	if name := expr.GetValueTemplateVariableName(); value == nil || name != "" {
		var found bool
		if value, found = templateValues[name]; !found {
			return fmt.Errorf("the value of expression template variable name {%s} is not found", name)
		}
	}
	castedValue, err := castValue(targetType, value)
	if err != nil {
		return err
	}
	expr.Value = castedValue
	return nil
}
// FillJSONContainsExpressionValue fills the element list of a templated
// contains-style expression from templateValues.
//
// If the expression already has elements and names no template variable,
// nothing is filled. Otherwise the template value is validated with
// checkContainsElement against the column and op. It is then cast to the
// column's data type, or to the element type for array columns. For the
// JSONContainsExpr_Contains op, the template value itself is the single
// element. For any other op, the items of its array value become the
// elements.
//
// The casted elements replace expr.Elements in one assignment. If a cast
// fails, expr.Elements is left untouched instead of half-populated, and
// repeated fills do not accumulate duplicate elements.
func FillJSONContainsExpressionValue(expr *planpb.JSONContainsExpr, templateValues map[string]*planpb.GenericValue) error {
	name := expr.GetTemplateVariableName()
	if expr.GetElements() != nil && name == "" {
		// Elements were supplied literally; there is no placeholder to fill.
		return nil
	}
	value, ok := templateValues[name]
	if !ok {
		return fmt.Errorf("the value of expression template variable name {%s} is not found", name)
	}
	if err := checkContainsElement(toColumnExpr(expr.GetColumnInfo()), expr.GetOp(), value); err != nil {
		return err
	}
	dataType := expr.GetColumnInfo().GetDataType()
	if typeutil.IsArrayType(dataType) {
		dataType = expr.GetColumnInfo().GetElementType()
	}

	var items []*planpb.GenericValue
	if expr.GetOp() == planpb.JSONContainsExpr_Contains {
		items = []*planpb.GenericValue{value}
	} else {
		items = value.GetArrayVal().GetArray()
	}

	var elements []*planpb.GenericValue
	for _, item := range items {
		castedValue, err := castValue(dataType, item)
		if err != nil {
			return err
		}
		elements = append(elements, castedValue)
	}
	expr.Elements = elements
	return nil
}