diff --git a/internal/proxy/task_insert.go b/internal/proxy/task_insert.go index 3b977ab2e3..7faf8d679b 100644 --- a/internal/proxy/task_insert.go +++ b/internal/proxy/task_insert.go @@ -148,6 +148,30 @@ func (it *insertTask) checkPrimaryFieldData() error { return nil } +func (it *insertTask) checkVectorFieldData() error { + fields := it.GetFieldsData() + for _, field := range fields { + if field.GetType() != schemapb.DataType_FloatVector { + continue + } + + vectorField := field.GetVectors() + if vectorField == nil || vectorField.GetFloatVector() == nil { + log.Error("float vector field is illegal, array type mismatch", zap.String("field name", field.GetFieldName())) + return fmt.Errorf("float vector field '%v' is illegal, array type mismatch", field.GetFieldName()) + } + + floatArray := vectorField.GetFloatVector() + err := typeutil.VerifyFloats32(floatArray.GetData()) + if err != nil { + log.Error("float vector field data is illegal", zap.String("field name", field.GetFieldName()), zap.Error(err)) + return fmt.Errorf("float vector field data is illegal, error: %w", err) + } + } + + return nil +} + func (it *insertTask) PreExecute(ctx context.Context) error { sp, ctx := trace.StartSpanFromContextWithOperationName(it.ctx, "Proxy-Insert-PreExecute") defer sp.Finish() @@ -229,6 +253,13 @@ func (it *insertTask) PreExecute(ctx context.Context) error { return err } + // check vector field data + err = it.checkVectorFieldData() + if err != nil { + log.Error("vector field data is illegal", zap.Int64("msgID", it.Base.MsgID), zap.String("collection name", collectionName), zap.Error(err)) + return err + } + log.Debug("Proxy Insert PreExecute done", zap.Int64("msgID", it.Base.MsgID), zap.String("collection name", collectionName)) return nil diff --git a/internal/proxy/task_insert_test.go b/internal/proxy/task_insert_test.go index 2b463d41b2..13acce2666 100644 --- a/internal/proxy/task_insert_test.go +++ b/internal/proxy/task_insert_test.go @@ -1,6 +1,7 @@ package proxy import ( + "math" "testing" "github.com/milvus-io/milvus-proto/go-api/commonpb" @@ -12,7 +13,6 @@ import ( func TestInsertTask_checkLengthOfFieldsData(t *testing.T) { var err error - // schema is empty, though won't happen in system case1 := insertTask{ schema: &schemapb.CollectionSchema{ Name: "TestInsertTask_checkLengthOfFieldsData", @@ -346,3 +346,84 @@ func TestInsertTask_CheckAligned(t *testing.T) { err = case2.CheckAligned() assert.NoError(t, err) } + +func TestInsertTask_CheckVectorFieldData(t *testing.T) { + fieldName := "embeddings" + numRows := 10 + dim := 32 + task := insertTask{ + BaseInsertTask: BaseInsertTask{ + InsertRequest: internalpb.InsertRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Insert, + }, + Version: internalpb.InsertDataVersion_ColumnBased, + NumRows: uint64(numRows), + }, + }, + schema: &schemapb.CollectionSchema{ + Name: "TestInsertTask_CheckVectorFieldData", + Description: "TestInsertTask_CheckVectorFieldData", + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: fieldName, + IsPrimaryKey: false, + AutoID: false, + DataType: schemapb.DataType_FloatVector, + }, + }, + }, + } + + // success case + task.FieldsData = []*schemapb.FieldData{ + newFloatVectorFieldData(fieldName, numRows, dim), + } + err := task.checkVectorFieldData() + assert.NoError(t, err) + + // field is nil + task.FieldsData = []*schemapb.FieldData{ + { + Type: schemapb.DataType_FloatVector, + FieldName: fieldName, + Field: &schemapb.FieldData_Vectors{ + Vectors: nil, + }, + }, + } + err = task.checkVectorFieldData() + assert.Error(t, err) + + // vector data is not a number + values := generateFloatVectors(numRows, dim) + values[5] = float32(math.NaN()) + task.FieldsData[0].Field = &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: int64(dim), + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: values, + }, + }, + }, + } + err = task.checkVectorFieldData() + assert.Error(t, err) + + // vector data is infinity + values[5] = float32(math.Inf(1)) + task.FieldsData[0].Field = &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: int64(dim), + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: values, + }, + }, + }, + } + err = task.checkVectorFieldData() + assert.Error(t, err) +} diff --git a/internal/util/importutil/import_util.go b/internal/util/importutil/import_util.go index c8a304d5c2..d295209ee4 100644 --- a/internal/util/importutil/import_util.go +++ b/internal/util/importutil/import_util.go @@ -21,7 +21,6 @@ import ( "encoding/json" "errors" "fmt" - "math" "path" "runtime/debug" "strconv" @@ -119,9 +118,9 @@ func parseFloat(s string, bitsize int, fieldName string) (float64, error) { return 0, fmt.Errorf("failed to parse value '%s' for field '%s', error: %w", s, fieldName, err) } - // not allow not-a-number and infinity - if math.IsNaN(value) || math.IsInf(value, -1) || math.IsInf(value, 1) { - return 0, fmt.Errorf("value '%s' is not a number or infinity, field '%s', error: %w", s, fieldName, err) + err = typeutil.VerifyFloat(value) + if err != nil { + return 0, fmt.Errorf("illegal value '%s' for field '%s', error: %w", s, fieldName, err) } return value, nil diff --git a/internal/util/importutil/import_util_test.go b/internal/util/importutil/import_util_test.go index afcbe590ec..fb31ae343c 100644 --- a/internal/util/importutil/import_util_test.go +++ b/internal/util/importutil/import_util_test.go @@ -298,6 +298,14 @@ func Test_parseFloat(t *testing.T) { value, err = parseFloat("2.718281828459045", 64, "") assert.True(t, math.Abs(value-2.718281828459045) < 0.0000000000000001) assert.Nil(t, err) + + value, err = parseFloat("Inf", 32, "") + assert.Zero(t, value) + assert.Error(t, err) + + value, err = parseFloat("NaN", 64, "") + assert.Zero(t, value) + assert.Error(t, err) } func Test_InitValidators(t *testing.T) { diff --git a/internal/util/importutil/numpy_parser.go b/internal/util/importutil/numpy_parser.go index 43af9a4726..dcd7a3e172 100644 --- a/internal/util/importutil/numpy_parser.go +++ b/internal/util/importutil/numpy_parser.go @@ -493,6 +493,12 @@ func (p *NumpyParser) readData(columnReader *NumpyColumnReader, rowCount int) (s return nil, fmt.Errorf("failed to read float array: %s", err.Error()) } + err = typeutil.VerifyFloats32(data) + if err != nil { + log.Error("Numpy parser: illegal value in float array", zap.Error(err)) + return nil, fmt.Errorf("illegal value in float array: %s", err.Error()) + } + return &storage.FloatFieldData{ Data: data, }, nil @@ -503,6 +509,12 @@ func (p *NumpyParser) readData(columnReader *NumpyColumnReader, rowCount int) (s return nil, fmt.Errorf("failed to read double array: %s", err.Error()) } + err = typeutil.VerifyFloats64(data) + if err != nil { + log.Error("Numpy parser: illegal value in double array", zap.Error(err)) + return nil, fmt.Errorf("illegal value in double array: %s", err.Error()) + } + return &storage.DoubleFieldData{ Data: data, }, nil @@ -541,6 +553,13 @@ func (p *NumpyParser) readData(columnReader *NumpyColumnReader, rowCount int) (s log.Error("Numpy parser: failed to read float vector array", zap.Error(err)) return nil, fmt.Errorf("failed to read float vector array: %s", err.Error()) } + + err = typeutil.VerifyFloats32(data) + if err != nil { + log.Error("Numpy parser: illegal value in float vector array", zap.Error(err)) + return nil, fmt.Errorf("illegal value in float vector array: %s", err.Error()) + } + } else if elementType == schemapb.DataType_Double { data = make([]float32, 0, columnReader.rowCount) data64, err := columnReader.reader.ReadFloat64(rowCount * columnReader.dimension) @@ -550,6 +569,12 @@ func (p *NumpyParser) readData(columnReader *NumpyColumnReader, rowCount int) (s } for _, f64 := range data64 { + err = typeutil.VerifyFloat(f64) + if err != nil { + log.Error("Numpy parser: illegal value in float vector array", zap.Error(err)) + return nil, fmt.Errorf("illegal value in float vector array: %s", err.Error()) + } + data = append(data, float32(f64)) } } diff --git a/internal/util/importutil/numpy_parser_test.go b/internal/util/importutil/numpy_parser_test.go index 49ce4f19fb..d477801818 100644 --- a/internal/util/importutil/numpy_parser_test.go +++ b/internal/util/importutil/numpy_parser_test.go @@ -19,6 +19,7 @@ package importutil import ( "context" "errors" + "math" "os" "testing" @@ -402,6 +403,22 @@ func Test_NumpyParserReadData(t *testing.T) { } } + readErrorFunc := func(filedName string, data interface{}) { + filePath := TempFilesPath + filedName + ".npy" + err = CreateNumpyFile(filePath, data) + assert.Nil(t, err) + + readers, err := parser.createReaders([]string{filePath}) + assert.NoError(t, err) + assert.Equal(t, 1, len(readers)) + defer closeReaders(readers) + + // encounter error + fieldData, err := parser.readData(readers[0], 1000) + assert.Error(t, err) + assert.Nil(t, fieldData) + } + t.Run("read bool", func(t *testing.T) { readEmptyFunc("FieldBool", []bool{}) @@ -442,6 +459,8 @@ func Test_NumpyParserReadData(t *testing.T) { data := []float32{2.5, 32.2, 53.254, 3.45, 65.23421, 54.8978} readBatchFunc("FieldFloat", data, len(data), func(k int) interface{} { return data[k] }) + data = []float32{2.5, 32.2, float32(math.NaN())} + readErrorFunc("FieldFloat", data) }) t.Run("read double", func(t *testing.T) { @@ -449,6 +468,8 @@ func Test_NumpyParserReadData(t *testing.T) { data := []float64{65.24454, 343.4365, 432.6556} readBatchFunc("FieldDouble", data, len(data), func(k int) interface{} { return data[k] }) + data = []float64{65.24454, math.Inf(1)} + readErrorFunc("FieldDouble", data) }) specialReadEmptyFunc := func(filedName string, data interface{}) { @@ -481,6 +502,9 @@ func Test_NumpyParserReadData(t *testing.T) { t.Run("read float vector", func(t *testing.T) { specialReadEmptyFunc("FieldFloatVector", [][4]float32{{1, 2, 3, 4}, {3, 4, 5, 6}}) specialReadEmptyFunc("FieldFloatVector", [][4]float64{{1, 2, 3, 4}, {3, 4, 5, 6}}) + + readErrorFunc("FieldFloatVector", [][4]float32{{1, 2, 3, float32(math.NaN())}, {3, 4, 5, 6}}) + readErrorFunc("FieldFloatVector", [][4]float64{{1, 2, 3, 4}, {3, 4, math.Inf(1), 6}}) }) } diff --git a/internal/util/typeutil/float_util.go b/internal/util/typeutil/float_util.go new file mode 100644 index 0000000000..bbc7abfebd --- /dev/null +++ b/internal/util/typeutil/float_util.go @@ -0,0 +1,53 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package typeutil + +import ( + "fmt" + "math" +) + +func VerifyFloat(value float64) error { + // not allow not-a-number and infinity + if math.IsNaN(value) || math.IsInf(value, -1) || math.IsInf(value, 1) { + return fmt.Errorf("value '%f' is not a number or infinity", value) + } + + return nil +} + +func VerifyFloats32(values []float32) error { + for _, f := range values { + err := VerifyFloat(float64(f)) + if err != nil { + return err + } + } + + return nil +} + +func VerifyFloats64(values []float64) error { + for _, f := range values { + err := VerifyFloat(f) + if err != nil { + return err + } + } + + return nil +} diff --git a/internal/util/typeutil/float_util_test.go b/internal/util/typeutil/float_util_test.go new file mode 100644 index 0000000000..6ac94aad96 --- /dev/null +++ b/internal/util/typeutil/float_util_test.go @@ -0,0 +1,66 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package typeutil + +import ( + "math" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_VerifyFloat(t *testing.T) { + var value = math.NaN() + err := VerifyFloat(value) + assert.Error(t, err) + + value = math.Inf(1) + err = VerifyFloat(value) + assert.Error(t, err) + + value = math.Inf(-1) + err = VerifyFloat(value) + assert.Error(t, err) +} + +func Test_VerifyFloats32(t *testing.T) { + data := []float32{2.5, 32.2, 53.254} + err := VerifyFloats32(data) + assert.NoError(t, err) + + data = []float32{2.5, 32.2, 53.254, float32(math.NaN())} + err = VerifyFloats32(data) + assert.Error(t, err) + + data = []float32{2.5, 32.2, 53.254, float32(math.Inf(1))} + err = VerifyFloats32(data) + assert.Error(t, err) +} + +func Test_VerifyFloats64(t *testing.T) { + data := []float64{2.5, 32.2, 53.254} + err := VerifyFloats64(data) + assert.NoError(t, err) + + data = []float64{2.5, 32.2, 53.254, math.NaN()} + err = VerifyFloats64(data) + assert.Error(t, err) + + data = []float64{2.5, 32.2, 53.254, math.Inf(-1)} + err = VerifyFloats64(data) + assert.Error(t, err) +}