From bdb8396e746d4a396615d1f447232d2bd5d7385c Mon Sep 17 00:00:00 2001 From: groot Date: Tue, 10 Aug 2021 11:59:28 +0800 Subject: [PATCH] Fix CalcDistance wrong result when fetching vectors from collection (#6976) * Fix CalcDistance wrong result when fetching vectors from collection Signed-off-by: yhmo * Fix CalcDistance wrong result when fetching vectors from collection Signed-off-by: yhmo * preset capacity Signed-off-by: yhmo * typo Signed-off-by: yhmo * error check Signed-off-by: yhmo * code lint Signed-off-by: yhmo --- internal/proxy/impl.go | 123 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 110 insertions(+), 13 deletions(-) diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index 5a43b6641f..62790e0593 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -1360,6 +1360,10 @@ func (node *Proxy) Retrieve(ctx context.Context, request *milvuspb.RetrieveReque zap.Any("partitions", request.PartitionNames), zap.Any("len(Ids)", len(request.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data))) defer func() { + idsCount := 0 + if rt.result != nil { + idsCount = len(rt.result.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data) + } log.Debug("Retrieve Done", zap.Error(err), zap.String("role", Params.RoleName), @@ -1368,7 +1372,7 @@ func (node *Proxy) Retrieve(ctx context.Context, request *milvuspb.RetrieveReque zap.String("db", request.DbName), zap.String("collection", request.CollectionName), zap.Any("partitions", request.PartitionNames), - zap.Any("len(Ids)", len(rt.result.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data))) + zap.Any("len(Ids)", idsCount)) }() err = rt.WaitToFinish() @@ -1593,6 +1597,80 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista return node.Retrieve(ctx, retrieveRequest) } + // the vectors retrieved are random order, we need re-arrange the vectors by the order of input ids + arrangeFunc := func(ids *milvuspb.VectorIDs, retrievedFields []*schemapb.FieldData) (*schemapb.VectorField, error) { 
var retrievedIds *schemapb.ScalarField + var retrievedVectors *schemapb.VectorField + for _, fieldData := range retrievedFields { + if fieldData.FieldName == ids.FieldName { + retrievedVectors = fieldData.GetVectors() + } + if fieldData.Type == schemapb.DataType_Int64 { + retrievedIds = fieldData.GetScalars() + } + } + + if retrievedIds == nil || retrievedVectors == nil { + return nil, errors.New("Failed to fetch vectors") + } + + dict := make(map[int64]int) + for index, id := range retrievedIds.GetLongData().Data { + dict[id] = index + } + + inputIds := ids.IdArray.GetIntId().Data + if retrievedVectors.GetFloatVector() != nil { + floatArr := retrievedVectors.GetFloatVector().Data + element := retrievedVectors.GetDim() + result := make([]float32, 0, int64(len(inputIds))*element) + for _, id := range inputIds { + index, ok := dict[id] + if !ok { + log.Error("id not found in CalcDistance", zap.Int64("id", id)) + return nil, errors.New("Failed to fetch vectors by id: " + fmt.Sprintln(id)) + } + result = append(result, floatArr[int64(index)*element:int64(index+1)*element]...) + } + + return &schemapb.VectorField{ + Dim: element, + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: result, + }, + }, + }, nil + } + + if retrievedVectors.GetBinaryVector() != nil { + binaryArr := retrievedVectors.GetBinaryVector() + element := retrievedVectors.GetDim() + if element%8 != 0 { + element = element + 8 - element%8 + } + + result := make([]byte, 0, int64(len(inputIds))*element) + for _, id := range inputIds { + index, ok := dict[id] + if !ok { + log.Error("id not found in CalcDistance", zap.Int64("id", id)) + return nil, errors.New("Failed to fetch vectors by id: " + fmt.Sprintln(id)) + } + result = append(result, binaryArr[int64(index)*element:int64(index+1)*element]...) 
+ } + + return &schemapb.VectorField{ + Dim: element * 8, + Data: &schemapb.VectorField_BinaryVector{ + BinaryVector: result, + }, + }, nil + } + + return nil, errors.New("Failed to fetch vectors") + } + vectorsLeft := request.GetOpLeft().GetDataArray() opLeft := request.GetOpLeft().GetIdArray() if opLeft != nil { @@ -1606,11 +1684,14 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista }, nil } - for _, fieldData := range result.FieldsData { - if fieldData.FieldName == opLeft.FieldName { - vectorsLeft = fieldData.GetVectors() - break - } + vectorsLeft, err = arrangeFunc(opLeft, result.FieldsData) + if err != nil { + return &milvuspb.CalcDistanceResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: err.Error(), + }, + }, nil } } @@ -1636,11 +1717,14 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista }, nil } - for _, fieldData := range result.FieldsData { - if fieldData.FieldName == opRight.FieldName { - vectorsRight = fieldData.GetVectors() - break - } + vectorsRight, err = arrangeFunc(opRight, result.FieldsData) + if err != nil { + return &milvuspb.CalcDistanceResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: err.Error(), + }, + }, nil } } @@ -1653,7 +1737,16 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista }, nil } - if vectorsLeft.Dim == vectorsRight.Dim && vectorsLeft.GetFloatVector() != nil && vectorsRight.GetFloatVector() != nil { + if vectorsLeft.Dim != vectorsRight.Dim { + return &milvuspb.CalcDistanceResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "Vectors dimension is not equal", + }, + }, nil + } + + if vectorsLeft.GetFloatVector() != nil && vectorsRight.GetFloatVector() != nil { distances, err := distance.CalcFloatDistance(vectorsLeft.Dim, vectorsLeft.GetFloatVector().Data, vectorsRight.GetFloatVector().Data, metric) if 
err != nil { return &milvuspb.CalcDistanceResults{ @@ -1674,7 +1767,7 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista }, nil } - if vectorsLeft.Dim == vectorsRight.Dim && vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetBinaryVector() != nil { + if vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetBinaryVector() != nil { hamming, err := distance.CalcHammingDistance(vectorsLeft.Dim, vectorsLeft.GetBinaryVector(), vectorsRight.GetBinaryVector()) if err != nil { return &milvuspb.CalcDistanceResults{ @@ -1719,6 +1812,10 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista } err = errors.New("Unexpected error") + if (vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetFloatVector() != nil) || (vectorsLeft.GetFloatVector() != nil && vectorsRight.GetBinaryVector() != nil) { + err = errors.New("Cannot calculate distance between binary vectors and float vectors") + } + return &milvuspb.CalcDistanceResults{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError,