diff --git a/internal/common/common.go b/internal/common/common.go index a3330c2ab3..eaf7257846 100644 --- a/internal/common/common.go +++ b/internal/common/common.go @@ -92,3 +92,7 @@ const ( PropertiesKey string = "properties" TraceIDKey string = "uber-trace-id" ) + +func IsSystemField(fieldID int64) bool { + return fieldID < StartOfUserFieldID +} diff --git a/internal/common/common_test.go b/internal/common/common_test.go new file mode 100644 index 0000000000..7228b1b6ab --- /dev/null +++ b/internal/common/common_test.go @@ -0,0 +1,40 @@ +package common + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsSystemField(t *testing.T) { + type args struct { + fieldID int64 + } + tests := []struct { + name string + args args + want bool + }{ + { + args: args{fieldID: StartOfUserFieldID}, + want: false, + }, + { + args: args{fieldID: StartOfUserFieldID + 1}, + want: false, + }, + { + args: args{fieldID: TimeStampField}, + want: true, + }, + { + args: args{fieldID: RowIDField}, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, IsSystemField(tt.args.fieldID), "IsSystemField(%v)", tt.args.fieldID) + }) + } +} diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index 9aaef8f5c3..d37d9a7d85 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -51,6 +51,7 @@ type queryTask struct { ids *schemapb.IDs collectionName string queryParams *queryParams + schema *schemapb.CollectionSchema resultBuf chan *internalpb.RetrieveResults toReduceResults []*internalpb.RetrieveResults @@ -111,6 +112,16 @@ func translateToOutputFieldIDs(outputFields []string, schema *schemapb.Collectio return outputFieldIDs, nil } +func filterSystemFields(outputFieldIDs []UniqueID) []UniqueID { + filtered := make([]UniqueID, 0, len(outputFieldIDs)) + for _, outputFieldID := range outputFieldIDs { + if !common.IsSystemField(outputFieldID) { + filtered = append(filtered, outputFieldID) + } + } + return filtered +} + // parseQueryParams get limit and offset from queryParamsPair, both are optional. func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, error) { var ( @@ -232,6 +243,7 @@ func (t *queryTask) PreExecute(ctx context.Context) error { } schema, _ := globalMetaCache.GetCollectionSchema(ctx, collectionName) + t.schema = schema if t.ids != nil { pkField := "" @@ -359,7 +371,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error { metrics.ProxyDecodeResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.QueryLabel).Observe(0.0) tr.CtxRecord(ctx, "reduceResultStart") - t.result, err = reduceRetrieveResults(ctx, t.toReduceResults, t.queryParams) + t.result, err = reduceRetrieveResultsAndFillIfEmpty(ctx, t.toReduceResults, t.queryParams, t.GetOutputFieldsId(), t.schema) if err != nil { return err } @@ -515,6 +527,21 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re return ret, nil } +func reduceRetrieveResultsAndFillIfEmpty(ctx context.Context, retrieveResults []*internalpb.RetrieveResults, queryParams *queryParams, outputFieldsID []int64, schema *schemapb.CollectionSchema) (*milvuspb.QueryResults, error) { + result, err := reduceRetrieveResults(ctx, retrieveResults, queryParams) + if err != nil { + return nil, err + } + + // filter system fields. 
+ filtered := filterSystemFields(outputFieldsID) + if err := typeutil.FillRetrieveResultIfEmpty(typeutil.NewMilvusResult(result), filtered, schema); err != nil { + return nil, fmt.Errorf("failed to fill retrieve results: %s", err.Error()) + } + + return result, nil +} + func (t *queryTask) TraceCtx() context.Context { return t.ctx } diff --git a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go index ceccd6e9f2..f4e24db894 100644 --- a/internal/proxy/task_query_test.go +++ b/internal/proxy/task_query_test.go @@ -657,3 +657,9 @@ func getFieldData(fieldName string, fieldID int64, fieldType schemapb.DataType, return fieldData } + +func Test_filterSystemFields(t *testing.T) { + outputFieldIDs := []UniqueID{common.RowIDField, common.TimeStampField, common.StartOfUserFieldID} + filtered := filterSystemFields(outputFieldIDs) + assert.ElementsMatch(t, []UniqueID{common.StartOfUserFieldID}, filtered) +} diff --git a/internal/querynode/impl.go b/internal/querynode/impl.go index 1301bd2cbe..4a92574970 100644 --- a/internal/querynode/impl.go +++ b/internal/querynode/impl.go @@ -1073,7 +1073,7 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.Que traceID, req.GetFromShardLeader(), dmlChannel, req.GetSegmentIDs())) results = append(results, streamingResult) - ret, err2 := mergeInternalRetrieveResult(ctx, results, req.Req.GetLimit()) + ret, err2 := mergeInternalRetrieveResultsAndFillIfEmpty(ctx, results, req.Req.GetLimit(), req.GetReq().GetOutputFieldsId(), qs.collection.Schema()) if err2 != nil { failRet.Status.Reason = err2.Error() return failRet, nil @@ -1115,6 +1115,13 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i }, } + coll, err := node.metaReplica.getCollectionByID(req.GetReq().GetCollectionID()) + if err != nil { + failRet.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError + failRet.Status.Reason = err.Error() + return failRet, nil + } + toMergeResults := make([]*internalpb.RetrieveResults, 0) runningGp, runningCtx := errgroup.WithContext(ctx) mu := &sync.Mutex{} @@ -1149,7 +1156,7 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i if err := runningGp.Wait(); err != nil { return failRet, nil } - ret, err := mergeInternalRetrieveResult(ctx, toMergeResults, req.GetReq().GetLimit()) + ret, err := mergeInternalRetrieveResultsAndFillIfEmpty(ctx, toMergeResults, req.GetReq().GetLimit(), req.GetReq().GetOutputFieldsId(), coll.Schema()) if err != nil { failRet.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError failRet.Status.Reason = err.Error() diff --git a/internal/querynode/result.go b/internal/querynode/result.go index a90853d2ed..ffa328b9cf 100644 --- a/internal/querynode/result.go +++ b/internal/querynode/result.go @@ -377,6 +377,46 @@ func mergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore return ret, nil } +func mergeSegcoreRetrieveResultsAndFillIfEmpty( + ctx context.Context, + retrieveResults []*segcorepb.RetrieveResults, + limit int64, + outputFieldsID []int64, + schema *schemapb.CollectionSchema, +) (*segcorepb.RetrieveResults, error) { + + mergedResult, err := mergeSegcoreRetrieveResults(ctx, retrieveResults, limit) + if err != nil { + return nil, err + } + + if err := typeutil.FillRetrieveResultIfEmpty(typeutil.NewSegcoreResults(mergedResult), outputFieldsID, schema); err != nil { + return nil, fmt.Errorf("failed to fill segcore retrieve results: %s", err.Error()) + } + + return mergedResult, nil +} + +func 
mergeInternalRetrieveResultsAndFillIfEmpty( + ctx context.Context, + retrieveResults []*internalpb.RetrieveResults, + limit int64, + outputFieldsID []int64, + schema *schemapb.CollectionSchema, +) (*internalpb.RetrieveResults, error) { + + mergedResult, err := mergeInternalRetrieveResult(ctx, retrieveResults, limit) + if err != nil { + return nil, err + } + + if err := typeutil.FillRetrieveResultIfEmpty(typeutil.NewInternalResult(mergedResult), outputFieldsID, schema); err != nil { + return nil, fmt.Errorf("failed to fill internal retrieve results: %s", err.Error()) + } + + return mergedResult, nil +} + // func printSearchResultData(data *schemapb.SearchResultData, header string) { // size := len(data.Ids.GetIntId().Data) // if size != len(data.Scores) { diff --git a/internal/querynode/task_query.go b/internal/querynode/task_query.go index 9e9fbfa34e..b620dac01c 100644 --- a/internal/querynode/task_query.go +++ b/internal/querynode/task_query.go @@ -61,7 +61,7 @@ func (q *queryTask) queryOnStreaming() error { } // check if collection has been released, check streaming since it's released first - _, err := q.QS.metaReplica.getCollectionByID(q.CollectionID) + coll, err := q.QS.metaReplica.getCollectionByID(q.CollectionID) if err != nil { return err } @@ -84,7 +84,7 @@ func (q *queryTask) queryOnStreaming() error { } q.tr.RecordSpan() - mergedResult, err := mergeSegcoreRetrieveResults(ctx, sResults, q.iReq.GetLimit()) + mergedResult, err := mergeSegcoreRetrieveResultsAndFillIfEmpty(ctx, sResults, q.iReq.GetLimit(), q.iReq.GetOutputFieldsId(), coll.Schema()) if err != nil { return err } @@ -106,7 +106,7 @@ func (q *queryTask) queryOnHistorical() error { } // check if collection has been released, check historical since it's released first - _, err := q.QS.metaReplica.getCollectionByID(q.CollectionID) + coll, err := q.QS.metaReplica.getCollectionByID(q.CollectionID) if err != nil { return err } @@ -127,10 +127,11 @@ func (q *queryTask) queryOnHistorical() error { return err } - mergedResult, err := mergeSegcoreRetrieveResults(ctx, retrieveResults, q.req.GetReq().GetLimit()) + mergedResult, err := mergeSegcoreRetrieveResultsAndFillIfEmpty(ctx, retrieveResults, q.req.GetReq().GetLimit(), q.iReq.GetOutputFieldsId(), coll.Schema()) if err != nil { return err } + q.Ret = &internalpb.RetrieveResults{ Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, Ids: mergedResult.Ids, diff --git a/internal/util/typeutil/gen_empty_field_data.go b/internal/util/typeutil/gen_empty_field_data.go new file mode 100644 index 0000000000..013f5b6141 --- /dev/null +++ b/internal/util/typeutil/gen_empty_field_data.go @@ -0,0 +1,163 @@ +package typeutil + +import ( + "fmt" + + "github.com/milvus-io/milvus-proto/go-api/schemapb" +) + +func fieldDataEmpty(data *schemapb.FieldData) bool { + if data == nil { + return true + } + switch realData := data.Field.(type) { + case *schemapb.FieldData_Scalars: + switch realScalars := realData.Scalars.Data.(type) { + case *schemapb.ScalarField_BoolData: + return len(realScalars.BoolData.GetData()) <= 0 + case *schemapb.ScalarField_LongData: + return len(realScalars.LongData.GetData()) <= 0 + case *schemapb.ScalarField_FloatData: + return len(realScalars.FloatData.GetData()) <= 0 + case *schemapb.ScalarField_DoubleData: + return len(realScalars.DoubleData.GetData()) <= 0 + case *schemapb.ScalarField_StringData: + return len(realScalars.StringData.GetData()) <= 0 + } + case *schemapb.FieldData_Vectors: + switch realVectors := realData.Vectors.Data.(type) { + case 
*schemapb.VectorField_BinaryVector: + return len(realVectors.BinaryVector) <= 0 + case *schemapb.VectorField_FloatVector: + return len(realVectors.FloatVector.Data) <= 0 + } + } + return true +} + +func genEmptyBoolFieldData(field *schemapb.FieldSchema) *schemapb.FieldData { + return &schemapb.FieldData{ + Type: field.GetDataType(), + FieldName: field.GetName(), + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_BoolData{BoolData: &schemapb.BoolArray{Data: nil}}, + }, + }, + FieldId: field.GetFieldID(), + } +} + +func genEmptyIntFieldData(field *schemapb.FieldSchema) *schemapb.FieldData { + return &schemapb.FieldData{ + Type: field.GetDataType(), + FieldName: field.GetName(), + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: nil}}, + }, + }, + FieldId: field.GetFieldID(), + } +} + +func genEmptyFloatFieldData(field *schemapb.FieldSchema) *schemapb.FieldData { + return &schemapb.FieldData{ + Type: field.GetDataType(), + FieldName: field.GetName(), + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_FloatData{FloatData: &schemapb.FloatArray{Data: nil}}, + }, + }, + FieldId: field.GetFieldID(), + } +} + +func genEmptyDoubleFieldData(field *schemapb.FieldSchema) *schemapb.FieldData { + return &schemapb.FieldData{ + Type: field.GetDataType(), + FieldName: field.GetName(), + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_DoubleData{DoubleData: &schemapb.DoubleArray{Data: nil}}, + }, + }, + FieldId: field.GetFieldID(), + } +} + +func genEmptyVarCharFieldData(field *schemapb.FieldSchema) *schemapb.FieldData { + return &schemapb.FieldData{ + Type: field.GetDataType(), + FieldName: field.GetName(), + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{Data: nil}}, + }, + }, + FieldId: field.GetFieldID(), + } +} + +func genEmptyBinaryVectorFieldData(field *schemapb.FieldSchema) (*schemapb.FieldData, error) { + dim, err := GetDim(field) + if err != nil { + return nil, err + } + return &schemapb.FieldData{ + Type: field.GetDataType(), + FieldName: field.GetName(), + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_BinaryVector{ + BinaryVector: nil, + }, + }, + }, + FieldId: field.GetFieldID(), + }, nil +} + +func genEmptyFloatVectorFieldData(field *schemapb.FieldSchema) (*schemapb.FieldData, error) { + dim, err := GetDim(field) + if err != nil { + return nil, err + } + return &schemapb.FieldData{ + Type: field.GetDataType(), + FieldName: field.GetName(), + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{Data: nil}, + }, + }, + }, + FieldId: field.GetFieldID(), + }, nil +} + +func genEmptyFieldData(field *schemapb.FieldSchema) (*schemapb.FieldData, error) { + dataType := field.GetDataType() + switch dataType { + case schemapb.DataType_Bool: + return genEmptyBoolFieldData(field), nil + case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32, schemapb.DataType_Int64: + return genEmptyIntFieldData(field), nil + case schemapb.DataType_Float: + return genEmptyFloatFieldData(field), nil + case schemapb.DataType_Double: + return genEmptyDoubleFieldData(field), nil + case 
schemapb.DataType_VarChar: + return genEmptyVarCharFieldData(field), nil + case schemapb.DataType_BinaryVector: + return genEmptyBinaryVectorFieldData(field) + case schemapb.DataType_FloatVector: + return genEmptyFloatVectorFieldData(field) + default: + return nil, fmt.Errorf("unsupported data type: %s", dataType.String()) + } +} diff --git a/internal/util/typeutil/get_dim.go b/internal/util/typeutil/get_dim.go new file mode 100644 index 0000000000..47f6840ec9 --- /dev/null +++ b/internal/util/typeutil/get_dim.go @@ -0,0 +1,25 @@ +package typeutil + +import ( + "fmt" + "strconv" + + "github.com/milvus-io/milvus-proto/go-api/schemapb" +) + +// GetDim get dimension of field. Maybe also helpful outside. +func GetDim(field *schemapb.FieldSchema) (int64, error) { + if !IsVectorType(field.GetDataType()) { + return 0, fmt.Errorf("%s is not of vector type", field.GetDataType()) + } + h := NewKvPairs(append(field.GetIndexParams(), field.GetTypeParams()...)) + dimStr, err := h.Get("dim") + if err != nil { + return 0, fmt.Errorf("dim not found") + } + dim, err := strconv.Atoi(dimStr) + if err != nil { + return 0, fmt.Errorf("invalid dimension: %s", dimStr) + } + return int64(dim), nil +} diff --git a/internal/util/typeutil/kv_pair_helper.go b/internal/util/typeutil/kv_pair_helper.go new file mode 100644 index 0000000000..b547b10a05 --- /dev/null +++ b/internal/util/typeutil/kv_pair_helper.go @@ -0,0 +1,31 @@ +package typeutil + +import ( + "fmt" + + "github.com/milvus-io/milvus-proto/go-api/commonpb" +) + +type kvPairsHelper[K comparable, V any] struct { + kvPairs map[K]V +} + +func (h *kvPairsHelper[K, V]) Get(k K) (V, error) { + v, ok := h.kvPairs[k] + if !ok { + return v, fmt.Errorf("%v not found", k) + } + return v, nil +} + +func NewKvPairs(pairs []*commonpb.KeyValuePair) *kvPairsHelper[string, string] { + helper := &kvPairsHelper[string, string]{ + kvPairs: make(map[string]string), + } + + for _, pair := range pairs { + helper.kvPairs[pair.GetKey()] = pair.GetValue() + } + + return helper +} diff --git a/internal/util/typeutil/kv_pair_helper_test.go b/internal/util/typeutil/kv_pair_helper_test.go new file mode 100644 index 0000000000..1791647a65 --- /dev/null +++ b/internal/util/typeutil/kv_pair_helper_test.go @@ -0,0 +1,20 @@ +package typeutil + +import ( + "testing" + + "github.com/milvus-io/milvus-proto/go-api/commonpb" + "github.com/stretchr/testify/assert" +) + +func TestNewKvPairs(t *testing.T) { + kvPairs := []*commonpb.KeyValuePair{ + {Key: "dim", Value: "128"}, + } + h := NewKvPairs(kvPairs) + v, err := h.Get("dim") + assert.NoError(t, err) + assert.Equal(t, "128", v) + _, err = h.Get("not_exist") + assert.Error(t, err) +} diff --git a/internal/util/typeutil/result_helper.go b/internal/util/typeutil/result_helper.go new file mode 100644 index 0000000000..9a4cdb6b17 --- /dev/null +++ b/internal/util/typeutil/result_helper.go @@ -0,0 +1,41 @@ +package typeutil + +import ( + "github.com/milvus-io/milvus-proto/go-api/schemapb" +) + +func preHandleEmptyResult(result RetrieveResults) { + result.PreHandle() +} + +func appendFieldData(result RetrieveResults, fieldData *schemapb.FieldData) { + result.AppendFieldData(fieldData) +} + +func FillRetrieveResultIfEmpty(result RetrieveResults, outputFieldIds []int64, schema *schemapb.CollectionSchema) error { + if !result.ResultEmpty() { + return nil + } + + preHandleEmptyResult(result) + + helper, err := CreateSchemaHelper(schema) + if err != nil { + return err + } + for _, outputFieldID := range outputFieldIds { + field, err := 
helper.GetFieldFromID(outputFieldID) + if err != nil { + return err + } + + emptyFieldData, err := genEmptyFieldData(field) + if err != nil { + return err + } + + appendFieldData(result, emptyFieldData) + } + + return nil +} diff --git a/internal/util/typeutil/result_helper_test.go b/internal/util/typeutil/result_helper_test.go new file mode 100644 index 0000000000..ffad286482 --- /dev/null +++ b/internal/util/typeutil/result_helper_test.go @@ -0,0 +1,253 @@ +package typeutil + +import ( + "testing" + + "github.com/milvus-io/milvus-proto/go-api/commonpb" + + "github.com/milvus-io/milvus-proto/go-api/milvuspb" + + "github.com/milvus-io/milvus/internal/proto/internalpb" + + "github.com/milvus-io/milvus/internal/proto/segcorepb" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/schemapb" +) + +func TestGenEmptyFieldData(t *testing.T) { + allTypes := []schemapb.DataType{ + schemapb.DataType_Bool, + schemapb.DataType_Int8, + schemapb.DataType_Int16, + schemapb.DataType_Int32, + schemapb.DataType_Int64, + schemapb.DataType_Float, + schemapb.DataType_Double, + schemapb.DataType_VarChar, + } + allUnsupportedTypes := []schemapb.DataType{ + schemapb.DataType_String, + schemapb.DataType_None, + } + vectorTypes := []schemapb.DataType{ + schemapb.DataType_BinaryVector, + schemapb.DataType_FloatVector, + } + + field := &schemapb.FieldSchema{Name: "field_name", FieldID: 100} + for _, dataType := range allTypes { + field.DataType = dataType + fieldData, err := genEmptyFieldData(field) + assert.NoError(t, err) + assert.Equal(t, dataType, fieldData.GetType()) + assert.Equal(t, field.GetName(), fieldData.GetFieldName()) + assert.True(t, fieldDataEmpty(fieldData)) + assert.Equal(t, field.GetFieldID(), fieldData.GetFieldId()) + } + + for _, dataType := range allUnsupportedTypes { + field.DataType = dataType + _, err := genEmptyFieldData(field) + assert.Error(t, err) + } + + // dim not found + for _, dataType := range vectorTypes { + field.DataType = dataType + _, err := genEmptyFieldData(field) + assert.Error(t, err) + } + + field.TypeParams = []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}} + for _, dataType := range vectorTypes { + field.DataType = dataType + fieldData, err := genEmptyFieldData(field) + assert.NoError(t, err) + assert.Equal(t, dataType, fieldData.GetType()) + assert.Equal(t, field.GetName(), fieldData.GetFieldName()) + assert.True(t, fieldDataEmpty(fieldData)) + assert.Equal(t, field.GetFieldID(), fieldData.GetFieldId()) + } +} + +func TestFillIfEmpty(t *testing.T) { + t.Run("not empty, do nothing", func(t *testing.T) { + result := &segcorepb.RetrieveResults{ + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1, 2}, + }, + }, + }, + } + err := FillRetrieveResultIfEmpty(NewSegcoreResults(result), []int64{100, 101}, nil) + assert.NoError(t, err) + }) + + t.Run("invalid schema", func(t *testing.T) { + result := &segcorepb.RetrieveResults{ + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: nil, + }, + }, + }, + } + err := FillRetrieveResultIfEmpty(NewSegcoreResults(result), []int64{100, 101}, nil) + assert.Error(t, err) + }) + + t.Run("field not found", func(t *testing.T) { + result := &segcorepb.RetrieveResults{ + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: nil, + }, + }, + }, + } + schema := &schemapb.CollectionSchema{ + Name: "collection", + Description: "description", + AutoID: false, + Fields: 
[]*schemapb.FieldSchema{ + { + FieldID: 100, + DataType: schemapb.DataType_Int64, + }, + }, + } + err := FillRetrieveResultIfEmpty(NewSegcoreResults(result), []int64{101}, schema) + assert.Error(t, err) + }) + + t.Run("unsupported data type", func(t *testing.T) { + result := &segcorepb.RetrieveResults{ + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: nil, + }, + }, + }, + } + schema := &schemapb.CollectionSchema{ + Name: "collection", + Description: "description", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + DataType: schemapb.DataType_String, + }, + }, + } + err := FillRetrieveResultIfEmpty(NewSegcoreResults(result), []int64{100}, schema) + assert.Error(t, err) + }) + + t.Run("normal case", func(t *testing.T) { + result := &segcorepb.RetrieveResults{ + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: nil, + }, + }, + }, + } + schema := &schemapb.CollectionSchema{ + Name: "collection", + Description: "description", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "field100", + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 101, + Name: "field101", + DataType: schemapb.DataType_VarChar, + }, + }, + } + err := FillRetrieveResultIfEmpty(NewSegcoreResults(result), []int64{100, 101}, schema) + assert.NoError(t, err) + assert.Nil(t, result.GetOffset()) + assert.Equal(t, 2, len(result.GetFieldsData())) + for _, fieldData := range result.GetFieldsData() { + assert.True(t, fieldDataEmpty(fieldData)) + } + }) + + t.Run("normal case", func(t *testing.T) { + result := &internalpb.RetrieveResults{ + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: nil, + }, + }, + }, + } + schema := &schemapb.CollectionSchema{ + Name: "collection", + Description: "description", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "field100", + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 101, + Name: "field101", + DataType: schemapb.DataType_VarChar, + }, + }, + } + err := FillRetrieveResultIfEmpty(NewInternalResult(result), []int64{100, 101}, schema) + assert.NoError(t, err) + assert.Equal(t, 2, len(result.GetFieldsData())) + for _, fieldData := range result.GetFieldsData() { + assert.True(t, fieldDataEmpty(fieldData)) + } + }) + + t.Run("normal case", func(t *testing.T) { + result := &milvuspb.QueryResults{ + FieldsData: nil, + } + schema := &schemapb.CollectionSchema{ + Name: "collection", + Description: "description", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "field100", + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 101, + Name: "field101", + DataType: schemapb.DataType_VarChar, + }, + }, + } + err := FillRetrieveResultIfEmpty(NewMilvusResult(result), []int64{100, 101}, schema) + assert.NoError(t, err) + assert.Equal(t, 2, len(result.GetFieldsData())) + for _, fieldData := range result.GetFieldsData() { + assert.True(t, fieldDataEmpty(fieldData)) + } + }) +} diff --git a/internal/util/typeutil/retrieve_result.go b/internal/util/typeutil/retrieve_result.go new file mode 100644 index 0000000000..8b9fc78207 --- /dev/null +++ b/internal/util/typeutil/retrieve_result.go @@ -0,0 +1,76 @@ +package typeutil + +import ( + "github.com/milvus-io/milvus-proto/go-api/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/schemapb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/segcorepb" +) + +type 
RetrieveResults interface {
+	PreHandle()
+	ResultEmpty() bool
+	AppendFieldData(fieldData *schemapb.FieldData)
+}
+
+type segcoreResults struct {
+	result *segcorepb.RetrieveResults
+}
+
+func (r *segcoreResults) PreHandle() {
+	r.result.Offset = nil
+	r.result.FieldsData = nil
+}
+
+func (r *segcoreResults) AppendFieldData(fieldData *schemapb.FieldData) {
+	r.result.FieldsData = append(r.result.FieldsData, fieldData)
+}
+
+func (r *segcoreResults) ResultEmpty() bool {
+	return GetSizeOfIDs(r.result.GetIds()) <= 0
+}
+
+func NewSegcoreResults(result *segcorepb.RetrieveResults) RetrieveResults {
+	return &segcoreResults{result: result}
+}
+
+type internalResults struct {
+	result *internalpb.RetrieveResults
+}
+
+func (r *internalResults) PreHandle() {
+	r.result.FieldsData = nil
+}
+
+func (r *internalResults) AppendFieldData(fieldData *schemapb.FieldData) {
+	r.result.FieldsData = append(r.result.FieldsData, fieldData)
+}
+
+func (r *internalResults) ResultEmpty() bool {
+	return GetSizeOfIDs(r.result.GetIds()) <= 0
+}
+
+func NewInternalResult(result *internalpb.RetrieveResults) RetrieveResults {
+	return &internalResults{result: result}
+}
+
+type milvusResults struct {
+	result *milvuspb.QueryResults
+}
+
+func (r *milvusResults) PreHandle() {
+	r.result.FieldsData = nil
+}
+
+func (r *milvusResults) AppendFieldData(fieldData *schemapb.FieldData) {
+	r.result.FieldsData = append(r.result.FieldsData, fieldData)
+}
+
+func (r *milvusResults) ResultEmpty() bool {
+	// QueryResults carries no dedicated ID column, so judge emptiness by the absence of field data.
+	return len(r.result.GetFieldsData()) <= 0
+}
+
+func NewMilvusResult(result *milvuspb.QueryResults) RetrieveResults {
+	return &milvusResults{result: result}
+}
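
For reference, a minimal usage sketch of the fill-if-empty path introduced above, modeled on the "normal case" subtests in result_helper_test.go; the standalone main package and the field names/IDs ("age", "vec", 100, 101) are illustrative and not part of this patch:

package main

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/schemapb"
	"github.com/milvus-io/milvus/internal/proto/segcorepb"
	"github.com/milvus-io/milvus/internal/util/typeutil"
)

func main() {
	// An empty segcore retrieve result: no IDs matched the expression.
	result := &segcorepb.RetrieveResults{
		Ids: &schemapb.IDs{
			IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: nil}},
		},
	}

	// Collection schema with a scalar field and a vector field (dim supplied via type params,
	// which is what GetDim reads when generating the empty vector column).
	schema := &schemapb.CollectionSchema{
		Name: "collection",
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "age", DataType: schemapb.DataType_Int64},
			{
				FieldID:    101,
				Name:       "vec",
				DataType:   schemapb.DataType_FloatVector,
				TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}},
			},
		},
	}

	// Wrap the concrete result type and append empty field data for each requested output field.
	if err := typeutil.FillRetrieveResultIfEmpty(typeutil.NewSegcoreResults(result), []int64{100, 101}, schema); err != nil {
		panic(err)
	}
	fmt.Println(len(result.GetFieldsData())) // 2 empty FieldData entries, one per output field
}

The same call shape serves the query-node and proxy paths by swapping in NewInternalResult or NewMilvusResult, which is why the new merge/reduce wrappers in result.go and task_query.go differ only in the RetrieveResults adapter they construct.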