From e3d78f55b3e1cb1418604370db1613cfdff4d2ef Mon Sep 17 00:00:00 2001 From: cqy123456 <39671710+cqy123456@users.noreply.github.com> Date: Mon, 25 Mar 2024 06:19:07 -0500 Subject: [PATCH] fix: add some check fieldata dim (#31564) issue:https://github.com/milvus-io/milvus/issues/30138 Signed-off-by: cqy123456 --- internal/proxy/validate_util.go | 21 +++- internal/proxy/validate_util_test.go | 174 +++++++++++++++++++++++++++ pkg/util/funcutil/func.go | 4 +- 3 files changed, 196 insertions(+), 3 deletions(-) diff --git a/internal/proxy/validate_util.go b/internal/proxy/validate_util.go index e75c79cde7..2a007adb13 100644 --- a/internal/proxy/validate_util.go +++ b/internal/proxy/validate_util.go @@ -127,7 +127,10 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil msg := fmt.Sprintf("the num_rows (%d) of field (%s) is not equal to passed num_rows (%d)", fieldNumRows, fieldName, numRows) return merr.WrapErrParameterInvalid(fieldNumRows, numRows, msg) } - + errDimMismatch := func(fieldName string, dataDim int64, schemaDim int64) error { + msg := fmt.Sprintf("the dim (%d) of field data(%s) is not equal to schema dim (%d)", dataDim, fieldName, schemaDim) + return merr.WrapErrParameterInvalid(dataDim, schemaDim, msg) + } for _, field := range data { switch field.GetType() { case schemapb.DataType_FloatVector: @@ -145,6 +148,10 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil if err != nil { return err } + dataDim := field.GetVectors().Dim + if dataDim != dim { + return errDimMismatch(field.GetFieldName(), dataDim, dim) + } if n != numRows { return errNumRowsMismatch(field.GetFieldName(), n) @@ -160,6 +167,10 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil if err != nil { return err } + dataDim := field.GetVectors().Dim + if dataDim != dim { + return errDimMismatch(field.GetFieldName(), dataDim, dim) + } n, err := 
funcutil.GetNumRowsOfBinaryVectorField(field.GetVectors().GetBinaryVector(), dim) if err != nil { @@ -180,6 +191,10 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil if err != nil { return err } + dataDim := field.GetVectors().Dim + if dataDim != dim { + return errDimMismatch(field.GetFieldName(), dataDim, dim) + } n, err := funcutil.GetNumRowsOfFloat16VectorField(field.GetVectors().GetFloat16Vector(), dim) if err != nil { @@ -200,6 +215,10 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil if err != nil { return err } + dataDim := field.GetVectors().Dim + if dataDim != dim { + return errDimMismatch(field.GetFieldName(), dataDim, dim) + } n, err := funcutil.GetNumRowsOfBFloat16VectorField(field.GetVectors().GetBfloat16Vector(), dim) if err != nil { diff --git a/internal/proxy/validate_util_test.go b/internal/proxy/validate_util_test.go index f56cdb9d58..c297f00076 100644 --- a/internal/proxy/validate_util_test.go +++ b/internal/proxy/validate_util_test.go @@ -316,6 +316,48 @@ func Test_validateUtil_checkAligned(t *testing.T) { assert.Error(t, err) }) + t.Run("field_data dim not match schema dim", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_FloatVector, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: []float32{1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8}, + }, + }, + Dim: 16, + }, + }, + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Float16Vector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "8", + }, + }, + }, + }, + } + h, err := typeutil.CreateSchemaHelper(schema) + assert.NoError(t, err) + + v := newValidateUtil() + + err = v.checkAligned(data, h, 1) + + assert.Error(t, err) + }) + t.Run("invalid num rows", func(t 
*testing.T) { data := []*schemapb.FieldData{ { @@ -328,6 +370,7 @@ func Test_validateUtil_checkAligned(t *testing.T) { Data: []float32{1.1, 2.2}, }, }, + Dim: 8, }, }, }, @@ -369,6 +412,7 @@ func Test_validateUtil_checkAligned(t *testing.T) { Data: []float32{1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8}, }, }, + Dim: 8, }, }, }, @@ -445,6 +489,46 @@ func Test_validateUtil_checkAligned(t *testing.T) { assert.Error(t, err) }) + t.Run("field data dim not match schema dim", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_BinaryVector, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_BinaryVector{ + BinaryVector: []byte("66666666"), + }, + Dim: 128, + }, + }, + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_BinaryVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "8", + }, + }, + }, + }, + } + h, err := typeutil.CreateSchemaHelper(schema) + assert.NoError(t, err) + + v := newValidateUtil() + + err = v.checkAligned(data, h, 100) + + assert.Error(t, err) + }) + t.Run("invalid num rows", func(t *testing.T) { data := []*schemapb.FieldData{ { @@ -455,6 +539,7 @@ func Test_validateUtil_checkAligned(t *testing.T) { Data: &schemapb.VectorField_BinaryVector{ BinaryVector: []byte("not128"), }, + Dim: 128, }, }, }, @@ -494,6 +579,7 @@ func Test_validateUtil_checkAligned(t *testing.T) { Data: &schemapb.VectorField_BinaryVector{ BinaryVector: []byte{'1', '2'}, }, + Dim: 8, }, }, }, @@ -580,6 +666,7 @@ func Test_validateUtil_checkAligned(t *testing.T) { Data: &schemapb.VectorField_Float16Vector{ Float16Vector: []byte("not128"), }, + Dim: 128, }, }, }, @@ -619,6 +706,7 @@ func Test_validateUtil_checkAligned(t *testing.T) { Data: &schemapb.VectorField_Float16Vector{ Float16Vector: []byte{'1', '2'}, }, + Dim: 2, }, }, }, @@ -648,6 +736,46 @@ func 
Test_validateUtil_checkAligned(t *testing.T) { assert.Error(t, err) }) + t.Run("field_data dim not match schema dim", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Float16Vector, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_Float16Vector{ + Float16Vector: []byte{'1', '2', '3', '4', '5', '6'}, + }, + Dim: 16, + }, + }, + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Float16Vector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "3", + }, + }, + }, + }, + } + h, err := typeutil.CreateSchemaHelper(schema) + assert.NoError(t, err) + + v := newValidateUtil() + + err = v.checkAligned(data, h, 1) + + assert.Error(t, err) + }) + ////////////////////////////////////////////////////////////////////// t.Run("bfloat16 vector column not found", func(t *testing.T) { @@ -695,6 +823,46 @@ func Test_validateUtil_checkAligned(t *testing.T) { assert.Error(t, err) }) + t.Run("field_data dim not match schema dim", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_BFloat16Vector, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_Bfloat16Vector{ + Bfloat16Vector: []byte{'1', '2', '3', '4', '5', '6'}, + }, + Dim: 16, + }, + }, + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_BFloat16Vector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "3", + }, + }, + }, + }, + } + h, err := typeutil.CreateSchemaHelper(schema) + assert.NoError(t, err) + + v := newValidateUtil() + + err = v.checkAligned(data, h, 1) + + assert.Error(t, err) + }) + t.Run("invalid num rows", func(t *testing.T) { data := []*schemapb.FieldData{ { @@ -705,6 +873,7 @@ func 
Test_validateUtil_checkAligned(t *testing.T) { Data: &schemapb.VectorField_Bfloat16Vector{ Bfloat16Vector: []byte("not128"), }, + Dim: 128, }, }, }, @@ -744,6 +913,7 @@ func Test_validateUtil_checkAligned(t *testing.T) { Data: &schemapb.VectorField_Bfloat16Vector{ Bfloat16Vector: []byte{'1', '2'}, }, + Dim: 2, }, }, }, @@ -830,6 +1000,7 @@ func Test_validateUtil_checkAligned(t *testing.T) { Data: generateFloatVectors(10, 8), }, }, + Dim: 8, }, }, }, @@ -841,6 +1012,7 @@ func Test_validateUtil_checkAligned(t *testing.T) { Data: &schemapb.VectorField_BinaryVector{ BinaryVector: generateBinaryVectors(10, 8), }, + Dim: 8, }, }, }, @@ -1949,6 +2121,7 @@ func Test_validateUtil_Validate(t *testing.T) { Type: schemapb.DataType_FloatVector, Field: &schemapb.FieldData_Vectors{ Vectors: &schemapb.VectorField{ + Dim: 8, Data: &schemapb.VectorField_FloatVector{ FloatVector: &schemapb.FloatArray{ Data: generateFloatVectors(2, 8), @@ -1962,6 +2135,7 @@ func Test_validateUtil_Validate(t *testing.T) { Type: schemapb.DataType_BinaryVector, Field: &schemapb.FieldData_Vectors{ Vectors: &schemapb.VectorField{ + Dim: 8, Data: &schemapb.VectorField_BinaryVector{ BinaryVector: generateBinaryVectors(2, 8), }, diff --git a/pkg/util/funcutil/func.go b/pkg/util/funcutil/func.go index 4c7e1dad5a..71e67a55c3 100644 --- a/pkg/util/funcutil/func.go +++ b/pkg/util/funcutil/func.go @@ -266,7 +266,7 @@ func GetNumRowsOfFloat16VectorField(f16Datas []byte, dim int64) (uint64, error) } l := len(f16Datas) if int64(l)%dim != 0 { - return 0, fmt.Errorf("the length(%d) of float data should divide the dim(%d)", l, dim) + return 0, fmt.Errorf("the length(%d) of float16 data should divide the dim(%d)", l, dim) } return uint64((int64(l)) / dim / 2), nil } @@ -277,7 +277,7 @@ func GetNumRowsOfBFloat16VectorField(bf16Datas []byte, dim int64) (uint64, error } l := len(bf16Datas) if int64(l)%dim != 0 { - return 0, fmt.Errorf("the length(%d) of float data should divide the dim(%d)", l, dim) + return 0, 
fmt.Errorf("the length(%d) of bfloat16 data should divide the dim(%d)", l, dim) } return uint64((int64(l)) / dim / 2), nil }