From 5016e509befcde740ad1f4c331bf4cc2451bed95 Mon Sep 17 00:00:00 2001 From: dragondriver Date: Wed, 14 Jul 2021 18:51:54 +0800 Subject: [PATCH] Support wildcard match on search/query output fields (#6510) Signed-off-by: dragondriver --- internal/proxy/task.go | 29 ++++++++++ internal/proxy/task_test.go | 103 ++++++++++++++++++++++++++++++++++++ 2 files changed, 132 insertions(+) diff --git a/internal/proxy/task.go b/internal/proxy/task.go index a2a520e270..d4c71fa05d 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -23,6 +23,7 @@ import ( "runtime" "sort" "strconv" + "strings" "time" "unsafe" @@ -1337,6 +1338,23 @@ func (st *SearchTask) getVChannels() ([]vChan, error) { return st.chMgr.getVChannels(collID) } +// https://github.com/milvus-io/milvus/issues/6411 +// Support wildcard match +func translateOutputFields(outputFields []string, schema *schemapb.CollectionSchema) ([]string, error) { + if len(outputFields) == 1 && strings.TrimSpace(outputFields[0]) == "*" { + ret := make([]string, 0) + // fill all fields except vector fields + for _, field := range schema.Fields { + if field.DataType != schemapb.DataType_BinaryVector && field.DataType != schemapb.DataType_FloatVector { + ret = append(ret, field.Name) + } + } + return ret, nil + } + + return outputFields, nil +} + func (st *SearchTask) PreExecute(ctx context.Context) error { st.Base.MsgType = commonpb.MsgType_Search st.Base.SourceID = Params.ProxyID @@ -1396,6 +1414,13 @@ func (st *SearchTask) PreExecute(ctx context.Context) error { if err != nil { // err is not nil if collection not exists return err } + + outputFields, err := translateOutputFields(st.query.OutputFields, schema) + if err != nil { + return err + } + st.query.OutputFields = outputFields + if st.query.GetDslType() == commonpb.DslType_BoolExprV1 { annsField, err := GetAttrByKeyFromRepeatedKV(AnnsFieldKey, st.query.SearchParams) if err != nil { @@ -2069,6 +2094,10 @@ func (rt *RetrieveTask) PreExecute(ctx context.Context) error { if err != nil { return err } + rt.retrieve.OutputFields, err = translateOutputFields(rt.retrieve.OutputFields, schema) + if err != nil { + return err + } if len(rt.retrieve.OutputFields) == 0 { for _, field := range schema.Fields { if field.FieldID >= 100 && field.DataType != schemapb.DataType_FloatVector && field.DataType != schemapb.DataType_BinaryVector { diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index 0bd7469638..cbd8975ce6 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/assert" ) +// TODO(dragondriver): add more test cases + func TestGetNumRowsOfScalarField(t *testing.T) { cases := []struct { datas interface{} @@ -386,3 +388,104 @@ func TestInsertTask_checkRowNums(t *testing.T) { err = case2.checkRowNums() assert.Equal(t, nil, err) } + +func TestTranslateOutputFields(t *testing.T) { + f1 := "field1" + f2 := "field2" + fvec := "fvec" + bvec := "bvec" + all := "*" + allWithWhiteSpace := " * " + allWithLeftWhiteSpace := " *" + allWithRightWhiteSpace := "* " + var outputFields []string + var err error + + // schema has no vector fields + schema1 := &schemapb.CollectionSchema{ + Name: "TestTranslateOutputFields", + Description: "TestTranslateOutputFields", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + {Name: f1, DataType: schemapb.DataType_Int64}, + {Name: f2, DataType: schemapb.DataType_Int64}, + }, + } + + outputFields, err = translateOutputFields([]string{}, schema1) + assert.Equal(t, nil, err) + assert.Equal(t, []string{}, outputFields) + + outputFields, err = translateOutputFields([]string{f1}, schema1) + assert.Equal(t, nil, err) + assert.Equal(t, []string{f1}, outputFields) + + outputFields, err = translateOutputFields([]string{f2}, schema1) + assert.Equal(t, nil, err) + assert.Equal(t, []string{f2}, outputFields) + + outputFields, err = translateOutputFields([]string{f1, f2}, schema1) + assert.Equal(t, nil, err) + assert.Equal(t, []string{f1, f2}, outputFields) + + outputFields, err = translateOutputFields([]string{all}, schema1) + assert.Equal(t, nil, err) + assert.Equal(t, []string{f1, f2}, outputFields) + + outputFields, err = translateOutputFields([]string{allWithWhiteSpace}, schema1) + assert.Equal(t, nil, err) + assert.Equal(t, []string{f1, f2}, outputFields) + + outputFields, err = translateOutputFields([]string{allWithLeftWhiteSpace}, schema1) + assert.Equal(t, nil, err) + assert.Equal(t, []string{f1, f2}, outputFields) + + outputFields, err = translateOutputFields([]string{allWithRightWhiteSpace}, schema1) + assert.Equal(t, nil, err) + assert.Equal(t, []string{f1, f2}, outputFields) + + // schema has vector fields + schema2 := &schemapb.CollectionSchema{ + Name: "TestTranslateOutputFields", + Description: "TestTranslateOutputFields", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + {Name: f1, DataType: schemapb.DataType_Int64}, + {Name: f2, DataType: schemapb.DataType_Int64}, + {Name: fvec, DataType: schemapb.DataType_FloatVector}, + {Name: bvec, DataType: schemapb.DataType_BinaryVector}, + }, + } + + outputFields, err = translateOutputFields([]string{}, schema2) + assert.Equal(t, nil, err) + assert.Equal(t, []string{}, outputFields) + + outputFields, err = translateOutputFields([]string{f1}, schema2) + assert.Equal(t, nil, err) + assert.Equal(t, []string{f1}, outputFields) + + outputFields, err = translateOutputFields([]string{f2}, schema2) + assert.Equal(t, nil, err) + assert.Equal(t, []string{f2}, outputFields) + + outputFields, err = translateOutputFields([]string{f1, f2}, schema2) + assert.Equal(t, nil, err) + assert.Equal(t, []string{f1, f2}, outputFields) + + outputFields, err = translateOutputFields([]string{all}, schema2) + assert.Equal(t, nil, err) + assert.Equal(t, []string{f1, f2}, outputFields) + + outputFields, err = translateOutputFields([]string{allWithWhiteSpace}, schema2) + assert.Equal(t, nil, err) + assert.Equal(t, []string{f1, f2}, outputFields) + + outputFields, err = translateOutputFields([]string{allWithLeftWhiteSpace}, schema2) + assert.Equal(t, nil, err) + assert.Equal(t, []string{f1, f2}, outputFields) + + outputFields, err = translateOutputFields([]string{allWithRightWhiteSpace}, schema2) + assert.Equal(t, nil, err) + assert.Equal(t, []string{f1, f2}, outputFields) +}