mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
Fix CalcDistance wrong result when fetching vectors from collection (#6976)
* Fix CalcDistance wrong result when fetching vectors from collection Signed-off-by: yhmo <yihua.mo@zilliz.com> * Fix CalcDistance wrong result when fetching vectors from collection Signed-off-by: yhmo <yihua.mo@zilliz.com> * preset capacity Signed-off-by: yhmo <yihua.mo@zilliz.com> * typo Signed-off-by: yhmo <yihua.mo@zilliz.com> * error check Signed-off-by: yhmo <yihua.mo@zilliz.com> * code lint Signed-off-by: yhmo <yihua.mo@zilliz.com>
This commit is contained in:
parent
3c3975b5ef
commit
bdb8396e74
@ -1360,6 +1360,10 @@ func (node *Proxy) Retrieve(ctx context.Context, request *milvuspb.RetrieveReque
|
||||
zap.Any("partitions", request.PartitionNames),
|
||||
zap.Any("len(Ids)", len(request.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data)))
|
||||
defer func() {
|
||||
idsCount := 0
|
||||
if rt.result != nil {
|
||||
idsCount = len(rt.result.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data)
|
||||
}
|
||||
log.Debug("Retrieve Done",
|
||||
zap.Error(err),
|
||||
zap.String("role", Params.RoleName),
|
||||
@ -1368,7 +1372,7 @@ func (node *Proxy) Retrieve(ctx context.Context, request *milvuspb.RetrieveReque
|
||||
zap.String("db", request.DbName),
|
||||
zap.String("collection", request.CollectionName),
|
||||
zap.Any("partitions", request.PartitionNames),
|
||||
zap.Any("len(Ids)", len(rt.result.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data)))
|
||||
zap.Any("len(Ids)", idsCount))
|
||||
}()
|
||||
|
||||
err = rt.WaitToFinish()
|
||||
@ -1593,6 +1597,80 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
|
||||
return node.Retrieve(ctx, retrieveRequest)
|
||||
}
|
||||
|
||||
// the vectors retrieved are random order, we need re-arrange the vectors by the order of input ids
|
||||
arrangeFunc := func(ids *milvuspb.VectorIDs, retrievedFields []*schemapb.FieldData) (*schemapb.VectorField, error) {
|
||||
var retrievedIds *schemapb.ScalarField
|
||||
var retrievedVectors *schemapb.VectorField
|
||||
for _, fieldData := range retrievedFields {
|
||||
if fieldData.FieldName == ids.FieldName {
|
||||
retrievedVectors = fieldData.GetVectors()
|
||||
}
|
||||
if fieldData.Type == schemapb.DataType_Int64 {
|
||||
retrievedIds = fieldData.GetScalars()
|
||||
}
|
||||
}
|
||||
|
||||
if retrievedIds == nil || retrievedVectors == nil {
|
||||
return nil, errors.New("Failed to fetch vectors")
|
||||
}
|
||||
|
||||
dict := make(map[int64]int)
|
||||
for index, id := range retrievedIds.GetLongData().Data {
|
||||
dict[id] = index
|
||||
}
|
||||
|
||||
inputIds := ids.IdArray.GetIntId().Data
|
||||
if retrievedVectors.GetFloatVector() != nil {
|
||||
floatArr := retrievedVectors.GetFloatVector().Data
|
||||
element := retrievedVectors.GetDim()
|
||||
result := make([]float32, 0, int64(len(inputIds))*element)
|
||||
for _, id := range inputIds {
|
||||
index, ok := dict[id]
|
||||
if !ok {
|
||||
log.Error("id not found in CalcDistance", zap.Int64("id", id))
|
||||
return nil, errors.New("Failed to fetch vectors by id: " + fmt.Sprintln(id))
|
||||
}
|
||||
result = append(result, floatArr[int64(index)*element:int64(index+1)*element]...)
|
||||
}
|
||||
|
||||
return &schemapb.VectorField{
|
||||
Dim: element,
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: result,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if retrievedVectors.GetBinaryVector() != nil {
|
||||
binaryArr := retrievedVectors.GetBinaryVector()
|
||||
element := retrievedVectors.GetDim()
|
||||
if element%8 != 0 {
|
||||
element = element + 8 - element%8
|
||||
}
|
||||
|
||||
result := make([]byte, 0, int64(len(inputIds))*element)
|
||||
for _, id := range inputIds {
|
||||
index, ok := dict[id]
|
||||
if !ok {
|
||||
log.Error("id not found in CalcDistance", zap.Int64("id", id))
|
||||
return nil, errors.New("Failed to fetch vectors by id: " + fmt.Sprintln(id))
|
||||
}
|
||||
result = append(result, binaryArr[int64(index)*element:int64(index+1)*element]...)
|
||||
}
|
||||
|
||||
return &schemapb.VectorField{
|
||||
Dim: element * 8,
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: result,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("Failed to fetch vectors")
|
||||
}
|
||||
|
||||
vectorsLeft := request.GetOpLeft().GetDataArray()
|
||||
opLeft := request.GetOpLeft().GetIdArray()
|
||||
if opLeft != nil {
|
||||
@ -1606,11 +1684,14 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
|
||||
}, nil
|
||||
}
|
||||
|
||||
for _, fieldData := range result.FieldsData {
|
||||
if fieldData.FieldName == opLeft.FieldName {
|
||||
vectorsLeft = fieldData.GetVectors()
|
||||
break
|
||||
}
|
||||
vectorsLeft, err = arrangeFunc(opLeft, result.FieldsData)
|
||||
if err != nil {
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -1636,11 +1717,14 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
|
||||
}, nil
|
||||
}
|
||||
|
||||
for _, fieldData := range result.FieldsData {
|
||||
if fieldData.FieldName == opRight.FieldName {
|
||||
vectorsRight = fieldData.GetVectors()
|
||||
break
|
||||
}
|
||||
vectorsRight, err = arrangeFunc(opRight, result.FieldsData)
|
||||
if err != nil {
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -1653,7 +1737,16 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
|
||||
}, nil
|
||||
}
|
||||
|
||||
if vectorsLeft.Dim == vectorsRight.Dim && vectorsLeft.GetFloatVector() != nil && vectorsRight.GetFloatVector() != nil {
|
||||
if vectorsLeft.Dim != vectorsRight.Dim {
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "Vectors dimension is not equal",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if vectorsLeft.GetFloatVector() != nil && vectorsRight.GetFloatVector() != nil {
|
||||
distances, err := distance.CalcFloatDistance(vectorsLeft.Dim, vectorsLeft.GetFloatVector().Data, vectorsRight.GetFloatVector().Data, metric)
|
||||
if err != nil {
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
@ -1674,7 +1767,7 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
|
||||
}, nil
|
||||
}
|
||||
|
||||
if vectorsLeft.Dim == vectorsRight.Dim && vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetBinaryVector() != nil {
|
||||
if vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetBinaryVector() != nil {
|
||||
hamming, err := distance.CalcHammingDistance(vectorsLeft.Dim, vectorsLeft.GetBinaryVector(), vectorsRight.GetBinaryVector())
|
||||
if err != nil {
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
@ -1719,6 +1812,10 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
|
||||
}
|
||||
|
||||
err = errors.New("Unexpected error")
|
||||
if (vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetFloatVector() != nil) || (vectorsLeft.GetFloatVector() != nil && vectorsRight.GetBinaryVector() != nil) {
|
||||
err = errors.New("Cannot calculate distance between binary vectors and float vectors")
|
||||
}
|
||||
|
||||
return &milvuspb.CalcDistanceResults{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user