mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
Verify vector float data for bulkinsert and insert (#22729)
Signed-off-by: yhmo <yihua.mo@zilliz.com>
This commit is contained in:
parent
407da8beca
commit
3aa28506a2
@ -148,6 +148,30 @@ func (it *insertTask) checkPrimaryFieldData() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *insertTask) checkVectorFieldData() error {
|
||||
fields := it.GetFieldsData()
|
||||
for _, field := range fields {
|
||||
if field.GetType() != schemapb.DataType_FloatVector {
|
||||
continue
|
||||
}
|
||||
|
||||
vectorField := field.GetVectors()
|
||||
if vectorField == nil || vectorField.GetFloatVector() == nil {
|
||||
log.Error("float vector field is illegal, array type mismatch", zap.String("field name", field.GetFieldName()))
|
||||
return fmt.Errorf("float vector field '%v' is illegal, array type mismatch", field.GetFieldName())
|
||||
}
|
||||
|
||||
floatArray := vectorField.GetFloatVector()
|
||||
err := typeutil.VerifyFloats32(floatArray.GetData())
|
||||
if err != nil {
|
||||
log.Error("float vector field data is illegal", zap.String("field name", field.GetFieldName()), zap.Error(err))
|
||||
return fmt.Errorf("float vector field data is illegal, error: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *insertTask) PreExecute(ctx context.Context) error {
|
||||
sp, ctx := trace.StartSpanFromContextWithOperationName(it.ctx, "Proxy-Insert-PreExecute")
|
||||
defer sp.Finish()
|
||||
@ -229,6 +253,13 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// check vector field data
|
||||
err = it.checkVectorFieldData()
|
||||
if err != nil {
|
||||
log.Error("vector field data is illegal", zap.Int64("msgID", it.Base.MsgID), zap.String("collection name", collectionName), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("Proxy Insert PreExecute done", zap.Int64("msgID", it.Base.MsgID), zap.String("collection name", collectionName))
|
||||
|
||||
return nil
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
@ -12,7 +13,6 @@ import (
|
||||
func TestInsertTask_checkLengthOfFieldsData(t *testing.T) {
|
||||
var err error
|
||||
|
||||
// schema is empty, though won't happen in system
|
||||
case1 := insertTask{
|
||||
schema: &schemapb.CollectionSchema{
|
||||
Name: "TestInsertTask_checkLengthOfFieldsData",
|
||||
@ -346,3 +346,84 @@ func TestInsertTask_CheckAligned(t *testing.T) {
|
||||
err = case2.CheckAligned()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestInsertTask_CheckVectorFieldData(t *testing.T) {
|
||||
fieldName := "embeddings"
|
||||
numRows := 10
|
||||
dim := 32
|
||||
task := insertTask{
|
||||
BaseInsertTask: BaseInsertTask{
|
||||
InsertRequest: internalpb.InsertRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Insert,
|
||||
},
|
||||
Version: internalpb.InsertDataVersion_ColumnBased,
|
||||
NumRows: uint64(numRows),
|
||||
},
|
||||
},
|
||||
schema: &schemapb.CollectionSchema{
|
||||
Name: "TestInsertTask_CheckVectorFieldData",
|
||||
Description: "TestInsertTask_CheckVectorFieldData",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: fieldName,
|
||||
IsPrimaryKey: false,
|
||||
AutoID: false,
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// success case
|
||||
task.FieldsData = []*schemapb.FieldData{
|
||||
newFloatVectorFieldData(fieldName, numRows, dim),
|
||||
}
|
||||
err := task.checkVectorFieldData()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// field is nil
|
||||
task.FieldsData = []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_FloatVector,
|
||||
FieldName: fieldName,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
err = task.checkVectorFieldData()
|
||||
assert.Error(t, err)
|
||||
|
||||
// vector data is not a number
|
||||
values := generateFloatVectors(numRows, dim)
|
||||
values[5] = float32(math.NaN())
|
||||
task.FieldsData[0].Field = &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: values,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err = task.checkVectorFieldData()
|
||||
assert.Error(t, err)
|
||||
|
||||
// vector data is infinity
|
||||
values[5] = float32(math.Inf(1))
|
||||
task.FieldsData[0].Field = &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: int64(dim),
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: values,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err = task.checkVectorFieldData()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
@ -21,7 +21,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"path"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
@ -119,9 +118,9 @@ func parseFloat(s string, bitsize int, fieldName string) (float64, error) {
|
||||
return 0, fmt.Errorf("failed to parse value '%s' for field '%s', error: %w", s, fieldName, err)
|
||||
}
|
||||
|
||||
// not allow not-a-number and infinity
|
||||
if math.IsNaN(value) || math.IsInf(value, -1) || math.IsInf(value, 1) {
|
||||
return 0, fmt.Errorf("value '%s' is not a number or infinity, field '%s', error: %w", s, fieldName, err)
|
||||
err = typeutil.VerifyFloat(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("illegal value '%s' for field '%s', error: %w", s, fieldName, err)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
|
||||
@ -298,6 +298,14 @@ func Test_parseFloat(t *testing.T) {
|
||||
value, err = parseFloat("2.718281828459045", 64, "")
|
||||
assert.True(t, math.Abs(value-2.718281828459045) < 0.0000000000000001)
|
||||
assert.Nil(t, err)
|
||||
|
||||
value, err = parseFloat("Inf", 32, "")
|
||||
assert.Zero(t, value)
|
||||
assert.Error(t, err)
|
||||
|
||||
value, err = parseFloat("NaN", 64, "")
|
||||
assert.Zero(t, value)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func Test_InitValidators(t *testing.T) {
|
||||
|
||||
@ -493,6 +493,12 @@ func (p *NumpyParser) readData(columnReader *NumpyColumnReader, rowCount int) (s
|
||||
return nil, fmt.Errorf("failed to read float array: %s", err.Error())
|
||||
}
|
||||
|
||||
err = typeutil.VerifyFloats32(data)
|
||||
if err != nil {
|
||||
log.Error("Numpy parser: illegal value in float array", zap.Error(err))
|
||||
return nil, fmt.Errorf("illegal value in float array: %s", err.Error())
|
||||
}
|
||||
|
||||
return &storage.FloatFieldData{
|
||||
Data: data,
|
||||
}, nil
|
||||
@ -503,6 +509,12 @@ func (p *NumpyParser) readData(columnReader *NumpyColumnReader, rowCount int) (s
|
||||
return nil, fmt.Errorf("failed to read double array: %s", err.Error())
|
||||
}
|
||||
|
||||
err = typeutil.VerifyFloats64(data)
|
||||
if err != nil {
|
||||
log.Error("Numpy parser: illegal value in double array", zap.Error(err))
|
||||
return nil, fmt.Errorf("illegal value in double array: %s", err.Error())
|
||||
}
|
||||
|
||||
return &storage.DoubleFieldData{
|
||||
Data: data,
|
||||
}, nil
|
||||
@ -541,6 +553,13 @@ func (p *NumpyParser) readData(columnReader *NumpyColumnReader, rowCount int) (s
|
||||
log.Error("Numpy parser: failed to read float vector array", zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to read float vector array: %s", err.Error())
|
||||
}
|
||||
|
||||
err = typeutil.VerifyFloats32(data)
|
||||
if err != nil {
|
||||
log.Error("Numpy parser: illegal value in float vector array", zap.Error(err))
|
||||
return nil, fmt.Errorf("illegal value in float vector array: %s", err.Error())
|
||||
}
|
||||
|
||||
} else if elementType == schemapb.DataType_Double {
|
||||
data = make([]float32, 0, columnReader.rowCount)
|
||||
data64, err := columnReader.reader.ReadFloat64(rowCount * columnReader.dimension)
|
||||
@ -550,6 +569,12 @@ func (p *NumpyParser) readData(columnReader *NumpyColumnReader, rowCount int) (s
|
||||
}
|
||||
|
||||
for _, f64 := range data64 {
|
||||
err = typeutil.VerifyFloat(f64)
|
||||
if err != nil {
|
||||
log.Error("Numpy parser: illegal value in float vector array", zap.Error(err))
|
||||
return nil, fmt.Errorf("illegal value in float vector array: %s", err.Error())
|
||||
}
|
||||
|
||||
data = append(data, float32(f64))
|
||||
}
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@ package importutil
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
@ -402,6 +403,22 @@ func Test_NumpyParserReadData(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
readErrorFunc := func(filedName string, data interface{}) {
|
||||
filePath := TempFilesPath + filedName + ".npy"
|
||||
err = CreateNumpyFile(filePath, data)
|
||||
assert.Nil(t, err)
|
||||
|
||||
readers, err := parser.createReaders([]string{filePath})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(readers))
|
||||
defer closeReaders(readers)
|
||||
|
||||
// encounter error
|
||||
fieldData, err := parser.readData(readers[0], 1000)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, fieldData)
|
||||
}
|
||||
|
||||
t.Run("read bool", func(t *testing.T) {
|
||||
readEmptyFunc("FieldBool", []bool{})
|
||||
|
||||
@ -442,6 +459,8 @@ func Test_NumpyParserReadData(t *testing.T) {
|
||||
|
||||
data := []float32{2.5, 32.2, 53.254, 3.45, 65.23421, 54.8978}
|
||||
readBatchFunc("FieldFloat", data, len(data), func(k int) interface{} { return data[k] })
|
||||
data = []float32{2.5, 32.2, float32(math.NaN())}
|
||||
readErrorFunc("FieldFloat", data)
|
||||
})
|
||||
|
||||
t.Run("read double", func(t *testing.T) {
|
||||
@ -449,6 +468,8 @@ func Test_NumpyParserReadData(t *testing.T) {
|
||||
|
||||
data := []float64{65.24454, 343.4365, 432.6556}
|
||||
readBatchFunc("FieldDouble", data, len(data), func(k int) interface{} { return data[k] })
|
||||
data = []float64{65.24454, math.Inf(1)}
|
||||
readErrorFunc("FieldDouble", data)
|
||||
})
|
||||
|
||||
specialReadEmptyFunc := func(filedName string, data interface{}) {
|
||||
@ -481,6 +502,9 @@ func Test_NumpyParserReadData(t *testing.T) {
|
||||
t.Run("read float vector", func(t *testing.T) {
|
||||
specialReadEmptyFunc("FieldFloatVector", [][4]float32{{1, 2, 3, 4}, {3, 4, 5, 6}})
|
||||
specialReadEmptyFunc("FieldFloatVector", [][4]float64{{1, 2, 3, 4}, {3, 4, 5, 6}})
|
||||
|
||||
readErrorFunc("FieldFloatVector", [][4]float32{{1, 2, 3, float32(math.NaN())}, {3, 4, 5, 6}})
|
||||
readErrorFunc("FieldFloatVector", [][4]float64{{1, 2, 3, 4}, {3, 4, math.Inf(1), 6}})
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
53
internal/util/typeutil/float_util.go
Normal file
53
internal/util/typeutil/float_util.go
Normal file
@ -0,0 +1,53 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package typeutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
)
|
||||
|
||||
func VerifyFloat(value float64) error {
|
||||
// not allow not-a-number and infinity
|
||||
if math.IsNaN(value) || math.IsInf(value, -1) || math.IsInf(value, 1) {
|
||||
return fmt.Errorf("value '%f' is not a number or infinity", value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func VerifyFloats32(values []float32) error {
|
||||
for _, f := range values {
|
||||
err := VerifyFloat(float64(f))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func VerifyFloats64(values []float64) error {
|
||||
for _, f := range values {
|
||||
err := VerifyFloat(f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
66
internal/util/typeutil/float_util_test.go
Normal file
66
internal/util/typeutil/float_util_test.go
Normal file
@ -0,0 +1,66 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package typeutil
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_VerifyFloat(t *testing.T) {
|
||||
var value = math.NaN()
|
||||
err := VerifyFloat(value)
|
||||
assert.Error(t, err)
|
||||
|
||||
value = math.Inf(1)
|
||||
err = VerifyFloat(value)
|
||||
assert.Error(t, err)
|
||||
|
||||
value = math.Inf(-1)
|
||||
err = VerifyFloat(value)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func Test_VerifyFloats32(t *testing.T) {
|
||||
data := []float32{2.5, 32.2, 53.254}
|
||||
err := VerifyFloats32(data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
data = []float32{2.5, 32.2, 53.254, float32(math.NaN())}
|
||||
err = VerifyFloats32(data)
|
||||
assert.Error(t, err)
|
||||
|
||||
data = []float32{2.5, 32.2, 53.254, float32(math.Inf(1))}
|
||||
err = VerifyFloats32(data)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func Test_VerifyFloats64(t *testing.T) {
|
||||
data := []float64{2.5, 32.2, 53.254}
|
||||
err := VerifyFloats64(data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
data = []float64{2.5, 32.2, 53.254, math.NaN()}
|
||||
err = VerifyFloats64(data)
|
||||
assert.Error(t, err)
|
||||
|
||||
data = []float64{2.5, 32.2, 53.254, math.Inf(-1)}
|
||||
err = VerifyFloats64(data)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user