diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index a17ee178b9..62724cf34e 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -26,6 +26,7 @@ import ( "github.com/milvus-io/milvus/internal/util" "go.uber.org/zap" + "go.uber.org/zap/zapcore" "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/log" @@ -38,10 +39,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/proto/querypb" - "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/util/crypto" - "github.com/milvus-io/milvus/internal/util/distance" - "github.com/milvus-io/milvus/internal/util/funcutil" "github.com/milvus-io/milvus/internal/util/logutil" "github.com/milvus-io/milvus/internal/util/metricsinfo" "github.com/milvus-io/milvus/internal/util/timerecord" @@ -3037,16 +3035,6 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista Status: unhealthyStatus(), }, nil } - param, _ := funcutil.GetAttrByKeyFromRepeatedKV("metric", request.GetParams()) - metric, err := distance.ValidateMetricType(param) - if err != nil { - return &milvuspb.CalcDistanceResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, - }, nil - } sp, ctx := trace.StartSpanFromContextWithOperationName(ctx, "Proxy-CalcDistance") defer sp.Finish() @@ -3080,15 +3068,15 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista queryShardPolicy: roundRobinPolicy, } + items := []zapcore.Field{ + zap.String("collection", queryRequest.CollectionName), + zap.Any("partitions", queryRequest.PartitionNames), + zap.Any("OutputFields", queryRequest.OutputFields), + } + err := node.sched.dqQueue.Enqueue(qt) if err != nil { - log.Debug("CalcDistance queryTask failed to enqueue", - zap.Error(err), - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole), - zap.String("db", queryRequest.DbName), - zap.String("collection", queryRequest.CollectionName), - zap.Any("partitions", queryRequest.PartitionNames)) + log.Error("CalcDistance queryTask failed to enqueue", append(items, zap.Error(err))...) return &milvuspb.QueryResults{ Status: &commonpb.Status{ @@ -3098,28 +3086,11 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista }, err } - log.Debug("CalcDistance queryTask enqueued", - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole), - zap.Int64("msgID", qt.Base.MsgID), - zap.Uint64("timestamp", qt.Base.Timestamp), - zap.String("db", queryRequest.DbName), - zap.String("collection", queryRequest.CollectionName), - zap.Any("partitions", queryRequest.PartitionNames), - zap.Any("OutputFields", queryRequest.OutputFields)) + log.Debug("CalcDistance queryTask enqueued", items...) err = qt.WaitToFinish() if err != nil { - log.Debug("CalcDistance queryTask failed to WaitToFinish", - zap.Error(err), - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole), - zap.Int64("msgID", qt.Base.MsgID), - zap.Uint64("timestamp", qt.Base.Timestamp), - zap.String("db", queryRequest.DbName), - zap.String("collection", queryRequest.CollectionName), - zap.Any("partitions", queryRequest.PartitionNames), - zap.Any("OutputFields", queryRequest.OutputFields)) + log.Error("CalcDistance queryTask failed to WaitToFinish", append(items, zap.Error(err))...) return &milvuspb.QueryResults{ Status: &commonpb.Status{ @@ -3129,15 +3100,7 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista }, err } - log.Debug("CalcDistance queryTask Done", - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole), - zap.Int64("msgID", qt.Base.MsgID), - zap.Uint64("timestamp", qt.Base.Timestamp), - zap.String("db", queryRequest.DbName), - zap.String("collection", queryRequest.CollectionName), - zap.Any("partitions", queryRequest.PartitionNames), - zap.Any("OutputFields", queryRequest.OutputFields)) + log.Debug("CalcDistance queryTask Done", items...) return &milvuspb.QueryResults{ Status: qt.result.Status, @@ -3145,328 +3108,13 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista }, nil } - // the vectors retrieved are random order, we need re-arrange the vectors by the order of input ids - arrangeFunc := func(ids *milvuspb.VectorIDs, retrievedFields []*schemapb.FieldData) (*schemapb.VectorField, error) { - var retrievedIds *schemapb.ScalarField - var retrievedVectors *schemapb.VectorField - for _, fieldData := range retrievedFields { - if fieldData.FieldName == ids.FieldName { - retrievedVectors = fieldData.GetVectors() - } - if fieldData.Type == schemapb.DataType_Int64 { - retrievedIds = fieldData.GetScalars() - } - } - - if retrievedIds == nil || retrievedVectors == nil { - return nil, errors.New("failed to fetch vectors") - } - - dict := make(map[int64]int) - for index, id := range retrievedIds.GetLongData().Data { - dict[id] = index - } - - inputIds := ids.IdArray.GetIntId().Data - if retrievedVectors.GetFloatVector() != nil { - floatArr := retrievedVectors.GetFloatVector().Data - element := retrievedVectors.GetDim() - result := make([]float32, 0, int64(len(inputIds))*element) - for _, id := range inputIds { - index, ok := dict[id] - if !ok { - log.Error("id not found in CalcDistance", zap.Int64("id", id)) - return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id)) - } - result = append(result, floatArr[int64(index)*element:int64(index+1)*element]...) - } - - return &schemapb.VectorField{ - Dim: element, - Data: &schemapb.VectorField_FloatVector{ - FloatVector: &schemapb.FloatArray{ - Data: result, - }, - }, - }, nil - } - - if retrievedVectors.GetBinaryVector() != nil { - binaryArr := retrievedVectors.GetBinaryVector() - element := retrievedVectors.GetDim() - if element%8 != 0 { - element = element + 8 - element%8 - } - - result := make([]byte, 0, int64(len(inputIds))*element) - for _, id := range inputIds { - index, ok := dict[id] - if !ok { - log.Error("id not found in CalcDistance", zap.Int64("id", id)) - return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id)) - } - result = append(result, binaryArr[int64(index)*element:int64(index+1)*element]...) - } - - return &schemapb.VectorField{ - Dim: element * 8, - Data: &schemapb.VectorField_BinaryVector{ - BinaryVector: result, - }, - }, nil - } - - return nil, errors.New("failed to fetch vectors") + // calcDistanceTask is not a standard task, no need to enqueue + task := &calcDistanceTask{ + traceID: traceID, + queryFunc: query, } - log.Debug("CalcDistance received", - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole), - zap.String("metric", metric)) - - vectorsLeft := request.GetOpLeft().GetDataArray() - opLeft := request.GetOpLeft().GetIdArray() - if opLeft != nil { - log.Debug("OpLeft IdArray not empty, Get vectors by id", - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole)) - - result, err := query(opLeft) - if err != nil { - log.Debug("Failed to get left vectors by id", - zap.Error(err), - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole)) - - return &milvuspb.CalcDistanceResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, - }, nil - } - - log.Debug("OpLeft IdArray not empty, Get vectors by id done", - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole)) - - vectorsLeft, err = arrangeFunc(opLeft, result.FieldsData) - if err != nil { - log.Debug("Failed to re-arrange left vectors", - zap.Error(err), - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole)) - - return &milvuspb.CalcDistanceResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, - }, nil - } - - log.Debug("Re-arrange left vectors done", - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole)) - } - - if vectorsLeft == nil { - msg := "Left vectors array is empty" - log.Debug(msg, - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole)) - - return &milvuspb.CalcDistanceResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: msg, - }, - }, nil - } - - vectorsRight := request.GetOpRight().GetDataArray() - opRight := request.GetOpRight().GetIdArray() - if opRight != nil { - log.Debug("OpRight IdArray not empty, Get vectors by id", - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole)) - - result, err := query(opRight) - if err != nil { - log.Debug("Failed to get right vectors by id", - zap.Error(err), - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole)) - - return &milvuspb.CalcDistanceResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, - }, nil - } - - log.Debug("OpRight IdArray not empty, Get vectors by id done", - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole)) - - vectorsRight, err = arrangeFunc(opRight, result.FieldsData) - if err != nil { - log.Debug("Failed to re-arrange right vectors", - zap.Error(err), - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole)) - - return &milvuspb.CalcDistanceResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, - }, nil - } - - log.Debug("Re-arrange right vectors done", - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole)) - } - - if vectorsRight == nil { - msg := "Right vectors array is empty" - log.Debug(msg, - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole)) - - return &milvuspb.CalcDistanceResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: msg, - }, - }, nil - } - - if vectorsLeft.Dim != vectorsRight.Dim { - msg := "Vectors dimension is not equal" - log.Debug(msg, - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole)) - - return &milvuspb.CalcDistanceResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: msg, - }, - }, nil - } - - if vectorsLeft.GetFloatVector() != nil && vectorsRight.GetFloatVector() != nil { - distances, err := distance.CalcFloatDistance(vectorsLeft.Dim, vectorsLeft.GetFloatVector().Data, vectorsRight.GetFloatVector().Data, metric) - if err != nil { - log.Debug("Failed to CalcFloatDistance", - zap.Error(err), - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole)) - - return &milvuspb.CalcDistanceResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, - }, nil - } - - log.Debug("CalcFloatDistance done", - zap.Error(err), - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole)) - - return &milvuspb.CalcDistanceResults{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success, Reason: ""}, - Array: &milvuspb.CalcDistanceResults_FloatDist{ - FloatDist: &schemapb.FloatArray{ - Data: distances, - }, - }, - }, nil - } - - if vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetBinaryVector() != nil { - hamming, err := distance.CalcHammingDistance(vectorsLeft.Dim, vectorsLeft.GetBinaryVector(), vectorsRight.GetBinaryVector()) - if err != nil { - log.Debug("Failed to CalcHammingDistance", - zap.Error(err), - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole)) - - return &milvuspb.CalcDistanceResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, - }, nil - } - - if metric == distance.HAMMING { - log.Debug("CalcHammingDistance done", - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole)) - - return &milvuspb.CalcDistanceResults{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success, Reason: ""}, - Array: &milvuspb.CalcDistanceResults_IntDist{ - IntDist: &schemapb.IntArray{ - Data: hamming, - }, - }, - }, nil - } - - if metric == distance.TANIMOTO { - tanimoto, err := distance.CalcTanimotoCoefficient(vectorsLeft.Dim, hamming) - if err != nil { - log.Debug("Failed to CalcTanimotoCoefficient", - zap.Error(err), - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole)) - - return &milvuspb.CalcDistanceResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, - }, nil - } - - log.Debug("CalcTanimotoCoefficient done", - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole)) - - return &milvuspb.CalcDistanceResults{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success, Reason: ""}, - Array: &milvuspb.CalcDistanceResults_FloatDist{ - FloatDist: &schemapb.FloatArray{ - Data: tanimoto, - }, - }, - }, nil - } - } - - err = errors.New("unexpected error") - if (vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetFloatVector() != nil) || (vectorsLeft.GetFloatVector() != nil && vectorsRight.GetBinaryVector() != nil) { - err = errors.New("cannot calculate distance between binary vectors and float vectors") - } - - log.Debug("Failed to CalcDistance", - zap.Error(err), - zap.String("traceID", traceID), - zap.String("role", typeutil.ProxyRole)) - - return &milvuspb.CalcDistanceResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, - }, nil + return task.Execute(ctx, request) } // GetDdChannel returns the used channel for dd operations. diff --git a/internal/proxy/task_calc_distance.go b/internal/proxy/task_calc_distance.go new file mode 100644 index 0000000000..19dae8729f --- /dev/null +++ b/internal/proxy/task_calc_distance.go @@ -0,0 +1,434 @@ +package proxy + +import ( + "context" + "errors" + "fmt" + + "github.com/milvus-io/milvus/internal/log" + "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/milvuspb" + "github.com/milvus-io/milvus/internal/proto/schemapb" + "github.com/milvus-io/milvus/internal/util/distance" + "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/typeutil" + "go.uber.org/zap" +) + +type calcDistanceTask struct { + traceID string + queryFunc func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error) +} + +func (t *calcDistanceTask) arrangeVectorsByIntID(inputIds []int64, sequence map[int64]int, retrievedVectors *schemapb.VectorField) (*schemapb.VectorField, error) { + if retrievedVectors.GetFloatVector() != nil { + floatArr := retrievedVectors.GetFloatVector().GetData() + element := retrievedVectors.GetDim() + result := make([]float32, 0, int64(len(inputIds))*element) + for _, id := range inputIds { + index, ok := sequence[id] + if !ok { + log.Error("id not found in CalcDistance", zap.Int64("id", id)) + return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id)) + } + result = append(result, floatArr[int64(index)*element:int64(index+1)*element]...) + } + + return &schemapb.VectorField{ + Dim: element, + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: result, + }, + }, + }, nil + } + + if retrievedVectors.GetBinaryVector() != nil { + binaryArr := retrievedVectors.GetBinaryVector() + singleBitLen := distance.SingleBitLen(retrievedVectors.GetDim()) + numBytes := singleBitLen / 8 + + result := make([]byte, 0, int64(len(inputIds))*numBytes) + for _, id := range inputIds { + index, ok := sequence[id] + if !ok { + log.Error("id not found in CalcDistance", zap.Int64("id", id)) + return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id)) + } + result = append(result, binaryArr[int64(index)*numBytes:int64(index+1)*numBytes]...) + } + + return &schemapb.VectorField{ + Dim: retrievedVectors.GetDim(), + Data: &schemapb.VectorField_BinaryVector{ + BinaryVector: result, + }, + }, nil + } + + return nil, errors.New("unsupported vector type") +} + +func (t *calcDistanceTask) arrangeVectorsByStrID(inputIds []string, sequence map[string]int, retrievedVectors *schemapb.VectorField) (*schemapb.VectorField, error) { + if retrievedVectors.GetFloatVector() != nil { + floatArr := retrievedVectors.GetFloatVector().GetData() + element := retrievedVectors.GetDim() + result := make([]float32, 0, int64(len(inputIds))*element) + for _, id := range inputIds { + index, ok := sequence[id] + if !ok { + log.Error("id not found in CalcDistance", zap.String("id", id)) + return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id)) + } + result = append(result, floatArr[int64(index)*element:int64(index+1)*element]...) + } + + return &schemapb.VectorField{ + Dim: element, + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: result, + }, + }, + }, nil + } + + if retrievedVectors.GetBinaryVector() != nil { + binaryArr := retrievedVectors.GetBinaryVector() + singleBitLen := distance.SingleBitLen(retrievedVectors.GetDim()) + numBytes := singleBitLen / 8 + + result := make([]byte, 0, int64(len(inputIds))*numBytes) + for _, id := range inputIds { + index, ok := sequence[id] + if !ok { + log.Error("id not found in CalcDistance", zap.String("id", id)) + return nil, errors.New("failed to fetch vectors by id: " + fmt.Sprintln(id)) + } + result = append(result, binaryArr[int64(index)*numBytes:int64(index+1)*numBytes]...) + } + + return &schemapb.VectorField{ + Dim: retrievedVectors.GetDim(), + Data: &schemapb.VectorField_BinaryVector{ + BinaryVector: result, + }, + }, nil + } + + return nil, errors.New("unsupported vector type") +} + +func (t *calcDistanceTask) Execute(ctx context.Context, request *milvuspb.CalcDistanceRequest) (*milvuspb.CalcDistanceResults, error) { + param, _ := funcutil.GetAttrByKeyFromRepeatedKV("metric", request.GetParams()) + metric, err := distance.ValidateMetricType(param) + if err != nil { + return &milvuspb.CalcDistanceResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: err.Error(), + }, + }, nil + } + + // the vectors retrieved are random order, we need re-arrange the vectors by the order of input ids + arrangeFunc := func(ids *milvuspb.VectorIDs, retrievedFields []*schemapb.FieldData) (*schemapb.VectorField, error) { + var retrievedIds *schemapb.ScalarField + var retrievedVectors *schemapb.VectorField + isStringID := true + for _, fieldData := range retrievedFields { + if fieldData.FieldName == ids.FieldName { + retrievedVectors = fieldData.GetVectors() + } + if fieldData.Type == schemapb.DataType_Int64 || + fieldData.Type == schemapb.DataType_VarChar || + fieldData.Type == schemapb.DataType_String { + retrievedIds = fieldData.GetScalars() + + if fieldData.Type == schemapb.DataType_Int64 { + isStringID = false + } + } + } + + if retrievedIds == nil || retrievedVectors == nil { + return nil, errors.New("failed to fetch vectors") + } + + if isStringID { + dict := make(map[string]int) + for index, id := range retrievedIds.GetStringData().GetData() { + dict[id] = index + } + + inputIds := ids.IdArray.GetStrId().GetData() + return t.arrangeVectorsByStrID(inputIds, dict, retrievedVectors) + } + + dict := make(map[int64]int) + for index, id := range retrievedIds.GetLongData().GetData() { + dict[id] = index + } + + inputIds := ids.IdArray.GetIntId().GetData() + return t.arrangeVectorsByIntID(inputIds, dict, retrievedVectors) + } + + log.Debug("CalcDistance received", + zap.String("traceID", t.traceID), + zap.String("role", typeutil.ProxyRole), + zap.String("metric", metric)) + + vectorsLeft := request.GetOpLeft().GetDataArray() + opLeft := request.GetOpLeft().GetIdArray() + if opLeft != nil { + log.Debug("OpLeft IdArray not empty, Get vectors by id", + zap.String("traceID", t.traceID), + zap.String("role", typeutil.ProxyRole)) + + result, err := t.queryFunc(opLeft) + if err != nil { + log.Debug("Failed to get left vectors by id", + zap.Error(err), + zap.String("traceID", t.traceID), + zap.String("role", typeutil.ProxyRole)) + + return &milvuspb.CalcDistanceResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: err.Error(), + }, + }, nil + } + + log.Debug("OpLeft IdArray not empty, Get vectors by id done", + zap.String("traceID", t.traceID), + zap.String("role", typeutil.ProxyRole)) + + vectorsLeft, err = arrangeFunc(opLeft, result.FieldsData) + if err != nil { + log.Debug("Failed to re-arrange left vectors", + zap.Error(err), + zap.String("traceID", t.traceID), + zap.String("role", typeutil.ProxyRole)) + + return &milvuspb.CalcDistanceResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: err.Error(), + }, + }, nil + } + + log.Debug("Re-arrange left vectors done", + zap.String("traceID", t.traceID), + zap.String("role", typeutil.ProxyRole)) + } + + if vectorsLeft == nil { + msg := "Left vectors array is empty" + log.Debug(msg, + zap.String("traceID", t.traceID), + zap.String("role", typeutil.ProxyRole)) + + return &milvuspb.CalcDistanceResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: msg, + }, + }, nil + } + + vectorsRight := request.GetOpRight().GetDataArray() + opRight := request.GetOpRight().GetIdArray() + if opRight != nil { + log.Debug("OpRight IdArray not empty, Get vectors by id", + zap.String("traceID", t.traceID), + zap.String("role", typeutil.ProxyRole)) + + result, err := t.queryFunc(opRight) + if err != nil { + log.Debug("Failed to get right vectors by id", + zap.Error(err), + zap.String("traceID", t.traceID), + zap.String("role", typeutil.ProxyRole)) + + return &milvuspb.CalcDistanceResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: err.Error(), + }, + }, nil + } + + log.Debug("OpRight IdArray not empty, Get vectors by id done", + zap.String("traceID", t.traceID), + zap.String("role", typeutil.ProxyRole)) + + vectorsRight, err = arrangeFunc(opRight, result.FieldsData) + if err != nil { + log.Debug("Failed to re-arrange right vectors", + zap.Error(err), + zap.String("traceID", t.traceID), + zap.String("role", typeutil.ProxyRole)) + + return &milvuspb.CalcDistanceResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: err.Error(), + }, + }, nil + } + + log.Debug("Re-arrange right vectors done", + zap.String("traceID", t.traceID), + zap.String("role", typeutil.ProxyRole)) + } + + if vectorsRight == nil { + msg := "Right vectors array is empty" + log.Debug(msg, + zap.String("traceID", t.traceID), + zap.String("role", typeutil.ProxyRole)) + + return &milvuspb.CalcDistanceResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: msg, + }, + }, nil + } + + if vectorsLeft.GetDim() != vectorsRight.GetDim() { + msg := "Vectors dimension is not equal" + log.Debug(msg, + zap.String("traceID", t.traceID), + zap.String("role", typeutil.ProxyRole)) + + return &milvuspb.CalcDistanceResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: msg, + }, + }, nil + } + + if vectorsLeft.GetFloatVector() != nil && vectorsRight.GetFloatVector() != nil { + distances, err := distance.CalcFloatDistance(vectorsLeft.GetDim(), vectorsLeft.GetFloatVector().GetData(), vectorsRight.GetFloatVector().GetData(), metric) + if err != nil { + log.Debug("Failed to CalcFloatDistance", + zap.Error(err), + zap.Int64("leftDim", vectorsLeft.GetDim()), + zap.Int("leftLen", len(vectorsLeft.GetFloatVector().GetData())), + zap.Int64("rightDim", vectorsRight.GetDim()), + zap.Int("rightLen", len(vectorsRight.GetFloatVector().GetData())), + zap.String("traceID", t.traceID), + zap.String("role", typeutil.ProxyRole)) + + return &milvuspb.CalcDistanceResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: err.Error(), + }, + }, nil + } + + log.Debug("CalcFloatDistance done", + zap.Error(err), + zap.String("traceID", t.traceID), + zap.String("role", typeutil.ProxyRole)) + + return &milvuspb.CalcDistanceResults{ + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success, Reason: ""}, + Array: &milvuspb.CalcDistanceResults_FloatDist{ + FloatDist: &schemapb.FloatArray{ + Data: distances, + }, + }, + }, nil + } + + if vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetBinaryVector() != nil { + hamming, err := distance.CalcHammingDistance(vectorsLeft.GetDim(), vectorsLeft.GetBinaryVector(), vectorsRight.GetBinaryVector()) + if err != nil { + log.Debug("Failed to CalcHammingDistance", + zap.Error(err), + zap.Int64("leftDim", vectorsLeft.GetDim()), + zap.Int("leftLen", len(vectorsLeft.GetBinaryVector())), + zap.Int64("rightDim", vectorsRight.GetDim()), + zap.Int("rightLen", len(vectorsRight.GetBinaryVector())), + zap.String("traceID", t.traceID), + zap.String("role", typeutil.ProxyRole)) + + return &milvuspb.CalcDistanceResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: err.Error(), + }, + }, nil + } + + if metric == distance.HAMMING { + log.Debug("CalcHammingDistance done", + zap.String("traceID", t.traceID), + zap.String("role", typeutil.ProxyRole)) + + return &milvuspb.CalcDistanceResults{ + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success, Reason: ""}, + Array: &milvuspb.CalcDistanceResults_IntDist{ + IntDist: &schemapb.IntArray{ + Data: hamming, + }, + }, + }, nil + } + + if metric == distance.TANIMOTO { + tanimoto, err := distance.CalcTanimotoCoefficient(vectorsLeft.GetDim(), hamming) + if err != nil { + log.Debug("Failed to CalcTanimotoCoefficient", + zap.Error(err), + zap.String("traceID", t.traceID), + zap.String("role", typeutil.ProxyRole)) + + return &milvuspb.CalcDistanceResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: err.Error(), + }, + }, nil + } + + log.Debug("CalcTanimotoCoefficient done", + zap.String("traceID", t.traceID), + zap.String("role", typeutil.ProxyRole)) + + return &milvuspb.CalcDistanceResults{ + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success, Reason: ""}, + Array: &milvuspb.CalcDistanceResults_FloatDist{ + FloatDist: &schemapb.FloatArray{ + Data: tanimoto, + }, + }, + }, nil + } + } + + err = errors.New("unexpected error") + if (vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetFloatVector() != nil) || (vectorsLeft.GetFloatVector() != nil && vectorsRight.GetBinaryVector() != nil) { + err = errors.New("cannot calculate distance between binary vectors and float vectors") + } + + log.Debug("Failed to CalcDistance", + zap.Error(err), + zap.String("traceID", t.traceID), + zap.String("role", typeutil.ProxyRole)) + + return &milvuspb.CalcDistanceResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: err.Error(), + }, + }, nil +} diff --git a/internal/proxy/task_calc_distance_test.go b/internal/proxy/task_calc_distance_test.go new file mode 100644 index 0000000000..52a8a6ec77 --- /dev/null +++ b/internal/proxy/task_calc_distance_test.go @@ -0,0 +1,490 @@ +package proxy + +import ( + "context" + "errors" + "testing" + + "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/milvuspb" + "github.com/milvus-io/milvus/internal/proto/schemapb" + "github.com/stretchr/testify/assert" +) + +func TestCalcDistanceTask_arrangeVectorsByStrID(t *testing.T) { + task := &calcDistanceTask{} + + inputIds := make([]string, 0) + inputIds = append(inputIds, "c") + inputIds = append(inputIds, "b") + inputIds = append(inputIds, "a") + + sequence := make(map[string]int) + sequence["a"] = 0 + sequence["b"] = 1 + sequence["c"] = 2 + + dim := 16 + + // float vector + floatValue := make([]float32, 0) + for i := 0; i < dim*3; i++ { + floatValue = append(floatValue, float32(i)) + } + retrievedVectors := &schemapb.VectorField{ + Dim: int64(dim), + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: floatValue, + }, + }, + } + + result, err := task.arrangeVectorsByStrID(inputIds, sequence, retrievedVectors) + assert.Nil(t, err) + + floatResult := result.GetFloatVector().GetData() + for i := 0; i < 3; i++ { + for j := 0; j < dim; j++ { + assert.Equal(t, floatValue[dim*sequence[inputIds[i]]+j], floatResult[i*dim+j]) + } + } + + // binary vector + binaryValue := make([]byte, 0) + for i := 0; i < 3*dim/8; i++ { + binaryValue = append(binaryValue, byte(i)) + } + retrievedVectors = &schemapb.VectorField{ + Dim: int64(dim), + Data: &schemapb.VectorField_BinaryVector{ + BinaryVector: binaryValue, + }, + } + + result, err = task.arrangeVectorsByStrID(inputIds, sequence, retrievedVectors) + assert.Nil(t, err) + + binaryResult := result.GetBinaryVector() + numBytes := dim / 8 + for i := 0; i < 3; i++ { + for j := 0; j < numBytes; j++ { + assert.Equal(t, binaryValue[sequence[inputIds[i]]*numBytes+j], binaryResult[i*numBytes+j]) + } + } +} + +func TestCalcDistanceTask_arrangeVectorsByIntID(t *testing.T) { + task := &calcDistanceTask{} + + inputIds := make([]int64, 0) + inputIds = append(inputIds, 2) + inputIds = append(inputIds, 0) + inputIds = append(inputIds, 1) + + sequence := make(map[int64]int) + sequence[0] = 0 + sequence[1] = 1 + sequence[2] = 2 + + dim := 16 + + // float vector + floatValue := make([]float32, 0) + for i := 0; i < dim*3; i++ { + floatValue = append(floatValue, float32(i)) + } + retrievedVectors := &schemapb.VectorField{ + Dim: int64(dim), + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: floatValue, + }, + }, + } + + result, err := task.arrangeVectorsByIntID(inputIds, sequence, retrievedVectors) + assert.Nil(t, err) + + floatResult := result.GetFloatVector().GetData() + for i := 0; i < 3; i++ { + for j := 0; j < dim; j++ { + assert.Equal(t, floatValue[dim*sequence[inputIds[i]]+j], floatResult[i*dim+j]) + } + } + + // binary vector + binaryValue := make([]byte, 0) + for i := 0; i < dim*3; i++ { + binaryValue = append(binaryValue, byte(i)) + } + retrievedVectors = &schemapb.VectorField{ + Dim: int64(dim), + Data: &schemapb.VectorField_BinaryVector{ + BinaryVector: binaryValue, + }, + } + + result, err = task.arrangeVectorsByIntID(inputIds, sequence, retrievedVectors) + assert.Nil(t, err) + + binaryResult := result.GetBinaryVector() + numBytes := dim / 8 + for i := 0; i < 3; i++ { + for j := 0; j < numBytes; j++ { + assert.Equal(t, binaryValue[sequence[inputIds[i]]*numBytes+j], binaryResult[i*numBytes+j]) + } + } +} + +func TestCalcDistanceTask_ExecuteFloat(t *testing.T) { + ctx := context.Background() + queryFunc := func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error) { + return nil, errors.New("unexpected error") + } + + task := &calcDistanceTask{ + traceID: "dummy", + queryFunc: queryFunc, + } + + request := &milvuspb.CalcDistanceRequest{ + OpLeft: nil, + OpRight: nil, + Params: []*commonpb.KeyValuePair{ + {Key: "metric", Value: "L2"}, + }, + } + + // left-op empty + calcResult, err := task.Execute(ctx, request) + assert.Nil(t, err) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode) + + request = &milvuspb.CalcDistanceRequest{ + OpLeft: &milvuspb.VectorsArray{ + Array: &milvuspb.VectorsArray_IdArray{ + IdArray: &milvuspb.VectorIDs{}, + }, + }, + OpRight: nil, + Params: []*commonpb.KeyValuePair{ + {Key: "metric", Value: "L2"}, + }, + } + + // left-op query error + calcResult, err = task.Execute(ctx, request) + assert.Nil(t, err) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode) + + fieldIds := make([]int64, 0) + fieldIds = append(fieldIds, 2) + fieldIds = append(fieldIds, 0) + fieldIds = append(fieldIds, 1) + + dim := 8 + floatValue := make([]float32, 0) + for i := 0; i < dim*3; i++ { + floatValue = append(floatValue, float32(i)) + } + + queryFunc = func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error) { + if ids == nil { + return &milvuspb.QueryResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "unexpected", + }, + }, nil + } + + return &milvuspb.QueryResults{ + FieldsData: []*schemapb.FieldData{ + { + Type: schemapb.DataType_Int64, + FieldName: "id", + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: fieldIds, + }, + }, + }, + }, + }, + { + Type: schemapb.DataType_FloatVector, + FieldName: "vec", + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: int64(dim), + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: floatValue, + }, + }, + }, + }, + }, + }, + }, nil + } + + task.queryFunc = queryFunc + calcResult, err = task.Execute(ctx, request) + assert.Nil(t, err) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode) + + idArray := &milvuspb.VectorsArray{ + Array: &milvuspb.VectorsArray_IdArray{ + IdArray: &milvuspb.VectorIDs{ + FieldName: "vec", + IdArray: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: fieldIds, + }, + }, + }, + }, + }, + } + request = &milvuspb.CalcDistanceRequest{ + OpLeft: idArray, + OpRight: idArray, + Params: []*commonpb.KeyValuePair{ + {Key: "metric", Value: "L2"}, + }, + } + + // success + calcResult, err = task.Execute(ctx, request) + assert.Nil(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, calcResult.Status.ErrorCode) + + // right-op query error + request.OpRight = nil + calcResult, err = task.Execute(ctx, request) + assert.Nil(t, err) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode) + + request.OpRight = &milvuspb.VectorsArray{ + Array: &milvuspb.VectorsArray_IdArray{ + IdArray: &milvuspb.VectorIDs{ + FieldName: "kkk", + IdArray: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: fieldIds, + }, + }, + }, + }, + }, + } + + // right-op arrange error + calcResult, err = task.Execute(ctx, request) + assert.Nil(t, err) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode) + + request.OpRight = &milvuspb.VectorsArray{ + Array: &milvuspb.VectorsArray_DataArray{ + DataArray: &schemapb.VectorField{ + Dim: 5, + }, + }, + } + + // different dimension + calcResult, err = task.Execute(ctx, request) + assert.Nil(t, err) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode) + + request.OpRight = &milvuspb.VectorsArray{ + Array: &milvuspb.VectorsArray_DataArray{ + DataArray: &schemapb.VectorField{ + Dim: int64(dim), + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: make([]float32, 0), + }, + }, + }, + }, + } + + // calcdistance return error + calcResult, err = task.Execute(ctx, request) + assert.Nil(t, err) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode) +} + +func TestCalcDistanceTask_ExecuteBinary(t *testing.T) { + ctx := context.Background() + + fieldIds := make([]int64, 0) + fieldIds = append(fieldIds, 2) + fieldIds = append(fieldIds, 0) + fieldIds = append(fieldIds, 1) + + dim := 16 + binaryValue := make([]byte, 0) + for i := 0; i < 3*dim/8; i++ { + binaryValue = append(binaryValue, byte(i)) + } + + queryFunc := func(ids *milvuspb.VectorIDs) (*milvuspb.QueryResults, error) { + if ids == nil { + return &milvuspb.QueryResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "unexpected", + }, + }, nil + } + + return &milvuspb.QueryResults{ + FieldsData: []*schemapb.FieldData{ + { + Type: schemapb.DataType_Int64, + FieldName: "id", + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: fieldIds, + }, + }, + }, + }, + }, + { + Type: schemapb.DataType_FloatVector, + FieldName: "vec", + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: int64(dim), + Data: &schemapb.VectorField_BinaryVector{ + BinaryVector: binaryValue, + }, + }, + }, + }, + }, + }, nil + } + + idArray := &milvuspb.VectorsArray{ + Array: &milvuspb.VectorsArray_IdArray{ + IdArray: &milvuspb.VectorIDs{ + FieldName: "vec", + IdArray: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: fieldIds, + }, + }, + }, + }, + }, + } + request := &milvuspb.CalcDistanceRequest{ + OpLeft: idArray, + OpRight: idArray, + Params: []*commonpb.KeyValuePair{ + {Key: "metric", Value: "HAMMING"}, + }, + } + + task := &calcDistanceTask{ + traceID: "dummy", + queryFunc: queryFunc, + } + + // success + calcResult, err := task.Execute(ctx, request) + assert.Nil(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, calcResult.Status.ErrorCode) + + floatArray := &milvuspb.VectorsArray{ + Array: &milvuspb.VectorsArray_DataArray{ + DataArray: &schemapb.VectorField{ + Dim: int64(dim), + Data: &schemapb.VectorField_FloatVector{}, + }, + }, + } + binaryArray := &milvuspb.VectorsArray{ + Array: &milvuspb.VectorsArray_DataArray{ + DataArray: &schemapb.VectorField{ + Dim: int64(dim), + Data: &schemapb.VectorField_BinaryVector{ + BinaryVector: binaryValue, + }, + }, + }, + } + request = &milvuspb.CalcDistanceRequest{ + OpLeft: floatArray, + OpRight: binaryArray, + Params: []*commonpb.KeyValuePair{ + {Key: "metric", Value: "HAMMING"}, + }, + } + + // float vs binary + calcResult, err = task.Execute(ctx, request) + assert.Nil(t, err) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode) + + request = &milvuspb.CalcDistanceRequest{ + OpLeft: binaryArray, + OpRight: binaryArray, + Params: []*commonpb.KeyValuePair{ + {Key: "metric", Value: "HAMMING"}, + }, + } + + // hamming + calcResult, err = task.Execute(ctx, request) + assert.Nil(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, calcResult.Status.ErrorCode) + + request = &milvuspb.CalcDistanceRequest{ + OpLeft: binaryArray, + OpRight: binaryArray, + Params: []*commonpb.KeyValuePair{ + {Key: "metric", Value: "TANIMOTO"}, + }, + } + + // tanimoto + calcResult, err = task.Execute(ctx, request) + assert.Nil(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, calcResult.Status.ErrorCode) + + request = &milvuspb.CalcDistanceRequest{ + OpLeft: binaryArray, + OpRight: &milvuspb.VectorsArray{ + Array: &milvuspb.VectorsArray_DataArray{ + DataArray: &schemapb.VectorField{ + Dim: int64(dim), + Data: &schemapb.VectorField_BinaryVector{ + BinaryVector: make([]byte, 0), + }, + }, + }, + }, + Params: []*commonpb.KeyValuePair{ + {Key: "metric", Value: "HAMMING"}, + }, + } + + // hamming error + calcResult, err = task.Execute(ctx, request) + assert.Nil(t, err) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, calcResult.Status.ErrorCode) +}