diff --git a/internal/util/importutil/import_util.go b/internal/util/importutil/import_util.go index 3b0341ae98..2379379e92 100644 --- a/internal/util/importutil/import_util.go +++ b/internal/util/importutil/import_util.go @@ -18,6 +18,7 @@ package importutil import ( "context" + "encoding/json" "errors" "fmt" "path" @@ -127,14 +128,15 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[ } // json decoder parse all the numeric value into float64 - numericValidator := func(obj interface{}) error { - switch obj.(type) { - case float64: - return nil - default: - return fmt.Errorf("illegal numeric value %v", obj) + numericValidator := func(fieldName string) func(obj interface{}) error { + return func(obj interface{}) error { + switch obj.(type) { + case json.Number: + return nil + default: + return fmt.Errorf("illegal value %v for numeric type field '%s'", obj, fieldName) + } } - } for i := 0; i < len(collectionSchema.Fields); i++ { @@ -153,7 +155,7 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[ case bool: return nil default: - return fmt.Errorf("illegal value %v for bool type field '%s'", obj, schema.GetName()) + return fmt.Errorf("illegal value '%v' for bool type field '%s'", obj, schema.GetName()) } } @@ -164,49 +166,73 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[ return nil } case schemapb.DataType_Float: - validators[schema.GetFieldID()].validateFunc = numericValidator + validators[schema.GetFieldID()].validateFunc = numericValidator(schema.GetName()) validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { - value := float32(obj.(float64)) - field.(*storage.FloatFieldData).Data = append(field.(*storage.FloatFieldData).Data, value) + value, err := strconv.ParseFloat(string(obj.(json.Number)), 32) + if err != nil { + return fmt.Errorf("failed to parse value '%v' for float type field '%s', error: %w", + obj, schema.GetName(), err) + } + field.(*storage.FloatFieldData).Data = append(field.(*storage.FloatFieldData).Data, float32(value)) field.(*storage.FloatFieldData).NumRows[0]++ return nil } case schemapb.DataType_Double: - validators[schema.GetFieldID()].validateFunc = numericValidator + validators[schema.GetFieldID()].validateFunc = numericValidator(schema.GetName()) validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { - value := obj.(float64) + value, err := strconv.ParseFloat(string(obj.(json.Number)), 32) + if err != nil { + return fmt.Errorf("failed to parse value '%v' for double type field '%s', error: %w", + obj, schema.GetName(), err) + } field.(*storage.DoubleFieldData).Data = append(field.(*storage.DoubleFieldData).Data, value) field.(*storage.DoubleFieldData).NumRows[0]++ return nil } case schemapb.DataType_Int8: - validators[schema.GetFieldID()].validateFunc = numericValidator + validators[schema.GetFieldID()].validateFunc = numericValidator(schema.GetName()) validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { - value := int8(obj.(float64)) - field.(*storage.Int8FieldData).Data = append(field.(*storage.Int8FieldData).Data, value) + value, err := strconv.ParseInt(string(obj.(json.Number)), 10, 8) + if err != nil { + return fmt.Errorf("failed to parse value '%v' for int8 type field '%s', error: %w", + obj, schema.GetName(), err) + } + field.(*storage.Int8FieldData).Data = append(field.(*storage.Int8FieldData).Data, int8(value)) field.(*storage.Int8FieldData).NumRows[0]++ return nil } case schemapb.DataType_Int16: - validators[schema.GetFieldID()].validateFunc = numericValidator + validators[schema.GetFieldID()].validateFunc = numericValidator(schema.GetName()) validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { - value := int16(obj.(float64)) - field.(*storage.Int16FieldData).Data = append(field.(*storage.Int16FieldData).Data, value) + value, err := strconv.ParseInt(string(obj.(json.Number)), 10, 16) + if err != nil { + return fmt.Errorf("failed to parse value '%v' for int16 type field '%s', error: %w", + obj, schema.GetName(), err) + } + field.(*storage.Int16FieldData).Data = append(field.(*storage.Int16FieldData).Data, int16(value)) field.(*storage.Int16FieldData).NumRows[0]++ return nil } case schemapb.DataType_Int32: - validators[schema.GetFieldID()].validateFunc = numericValidator + validators[schema.GetFieldID()].validateFunc = numericValidator(schema.GetName()) validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { - value := int32(obj.(float64)) - field.(*storage.Int32FieldData).Data = append(field.(*storage.Int32FieldData).Data, value) + value, err := strconv.ParseInt(string(obj.(json.Number)), 10, 32) + if err != nil { + return fmt.Errorf("failed to parse value '%v' for int32 type field '%s', error: %w", + obj, schema.GetName(), err) + } + field.(*storage.Int32FieldData).Data = append(field.(*storage.Int32FieldData).Data, int32(value)) field.(*storage.Int32FieldData).NumRows[0]++ return nil } case schemapb.DataType_Int64: - validators[schema.GetFieldID()].validateFunc = numericValidator + validators[schema.GetFieldID()].validateFunc = numericValidator(schema.GetName()) validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { - value := int64(obj.(float64)) + value, err := strconv.ParseInt(string(obj.(json.Number)), 10, 64) + if err != nil { + return fmt.Errorf("failed to parse value '%v' for int64 type field '%s', error: %w", + obj, schema.GetName(), err) + } field.(*storage.Int64FieldData).Data = append(field.(*storage.Int64FieldData).Data, value) field.(*storage.Int64FieldData).NumRows[0]++ return nil @@ -224,27 +250,27 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[ if len(vt)*8 != dim { return fmt.Errorf("bit size %d doesn't equal to vector dimension %d of field '%s'", len(vt)*8, dim, schema.GetName()) } + numValidateFunc := numericValidator(schema.GetName()) for i := 0; i < len(vt); i++ { - if e := numericValidator(vt[i]); e != nil { - return fmt.Errorf("%s for binary vector field '%s'", e.Error(), schema.GetName()) - } - - t := int(vt[i].(float64)) - if t > 255 || t < 0 { - return fmt.Errorf("illegal value %d for binary vector field '%s'", t, schema.GetName()) + if e := numValidateFunc(vt[i]); e != nil { + return e } } return nil default: - return fmt.Errorf("%v is not an array for binary vector field '%s'", obj, schema.GetName()) + return fmt.Errorf("'%v' is not an array for binary vector field '%s'", obj, schema.GetName()) } } validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { arr := obj.([]interface{}) for i := 0; i < len(arr); i++ { - value := byte(arr[i].(float64)) - field.(*storage.BinaryVectorFieldData).Data = append(field.(*storage.BinaryVectorFieldData).Data, value) + value, err := strconv.ParseUint(string(arr[i].(json.Number)), 10, 8) + if err != nil { + return fmt.Errorf("failed to parse value '%v' for binary vector type field '%s', error: %w", + obj, schema.GetName(), err) + } + field.(*storage.BinaryVectorFieldData).Data = append(field.(*storage.BinaryVectorFieldData).Data, byte(value)) } field.(*storage.BinaryVectorFieldData).NumRows[0]++ @@ -263,22 +289,27 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[ if len(vt) != dim { return fmt.Errorf("array size %d doesn't equal to vector dimension %d of field '%s'", len(vt), dim, schema.GetName()) } + numValidateFunc := numericValidator(schema.GetName()) for i := 0; i < len(vt); i++ { - if e := numericValidator(vt[i]); e != nil { - return fmt.Errorf("%s for float vector field '%s'", e.Error(), schema.GetName()) + if e := numValidateFunc(vt[i]); e != nil { + return e } } return nil default: - return fmt.Errorf("%v is not an array for float vector field '%s'", obj, schema.GetName()) + return fmt.Errorf("'%v' is not an array for float vector field '%s'", obj, schema.GetName()) } } validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { arr := obj.([]interface{}) for i := 0; i < len(arr); i++ { - value := float32(arr[i].(float64)) - field.(*storage.FloatVectorFieldData).Data = append(field.(*storage.FloatVectorFieldData).Data, value) + value, err := strconv.ParseFloat(string(arr[i].(json.Number)), 32) + if err != nil { + return fmt.Errorf("failed to parse value '%v' for binary vector type field '%s', error: %w", + obj, schema.GetName(), err) + } + field.(*storage.FloatVectorFieldData).Data = append(field.(*storage.FloatVectorFieldData).Data, float32(value)) } field.(*storage.FloatVectorFieldData).NumRows[0]++ return nil @@ -290,7 +321,7 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[ case string: return nil default: - return fmt.Errorf("%v is not a string for string type field '%s'", obj, schema.GetName()) + return fmt.Errorf("'%v' is not a string for varchar type field '%s'", obj, schema.GetName()) } } diff --git a/internal/util/importutil/import_util_test.go b/internal/util/importutil/import_util_test.go index 4ad06336b0..eb0d44ee89 100644 --- a/internal/util/importutil/import_util_test.go +++ b/internal/util/importutil/import_util_test.go @@ -17,6 +17,7 @@ package importutil import ( "context" + "encoding/json" "errors" "testing" @@ -180,6 +181,10 @@ func strKeySchema() *schemapb.CollectionSchema { return schema } +func jsonNumber(value string) json.Number { + return json.Number(value) +} + func Test_IsCanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) @@ -245,7 +250,7 @@ func Test_InitValidators(t *testing.T) { fields := initSegmentData(schema) assert.NotNil(t, fields) - checkFunc := func(funcName string, validVal interface{}, invalidVal interface{}) { + checkValidateFunc := func(funcName string, validVal interface{}, invalidVal interface{}) { id := name2ID[funcName] v, ok := validators[id] assert.True(t, ok) @@ -253,6 +258,12 @@ func Test_InitValidators(t *testing.T) { assert.Nil(t, err) err = v.validateFunc(invalidVal) assert.NotNil(t, err) + } + + checkConvertFunc := func(funcName string, validVal interface{}, invalidVal interface{}) { + id := name2ID[funcName] + v, ok := validators[id] + assert.True(t, ok) fieldData := fields[id] preNum := fieldData.RowNum() @@ -260,94 +271,133 @@ func Test_InitValidators(t *testing.T) { assert.Nil(t, err) postNum := fieldData.RowNum() assert.Equal(t, 1, postNum-preNum) + + if invalidVal != nil { + err = v.convertFunc(invalidVal, fieldData) + assert.NotNil(t, err) + } } - // validate functions - var validVal interface{} = true - var invalidVal interface{} = "aa" - checkFunc("field_bool", validVal, invalidVal) + t.Run("check validate functions", func(t *testing.T) { + var validVal interface{} = true + var invalidVal interface{} = "aa" + checkValidateFunc("field_bool", validVal, invalidVal) - validVal = float64(100) - invalidVal = "aa" - checkFunc("field_int8", validVal, invalidVal) - checkFunc("field_int16", validVal, invalidVal) - checkFunc("field_int32", validVal, invalidVal) - checkFunc("field_int64", validVal, invalidVal) - checkFunc("field_float", validVal, invalidVal) - checkFunc("field_double", validVal, invalidVal) + validVal = jsonNumber("100") + invalidVal = "aa" + checkValidateFunc("field_int8", validVal, invalidVal) + checkValidateFunc("field_int16", validVal, invalidVal) + checkValidateFunc("field_int32", validVal, invalidVal) + checkValidateFunc("field_int64", validVal, invalidVal) + checkValidateFunc("field_float", validVal, invalidVal) + checkValidateFunc("field_double", validVal, invalidVal) - validVal = "aa" - invalidVal = 100 - checkFunc("field_string", validVal, invalidVal) + validVal = "aa" + invalidVal = 100 + checkValidateFunc("field_string", validVal, invalidVal) - validVal = []interface{}{float64(100), float64(101)} - invalidVal = "aa" - checkFunc("field_binary_vector", validVal, invalidVal) - invalidVal = []interface{}{float64(100)} - checkFunc("field_binary_vector", validVal, invalidVal) - invalidVal = []interface{}{float64(100), float64(101), float64(102)} - checkFunc("field_binary_vector", validVal, invalidVal) - invalidVal = []interface{}{true, true} - checkFunc("field_binary_vector", validVal, invalidVal) - invalidVal = []interface{}{float64(255), float64(-1)} - checkFunc("field_binary_vector", validVal, invalidVal) + // the binary vector dimension is 16, shoud input 2 uint8 values + validVal = []interface{}{jsonNumber("100"), jsonNumber("101")} + invalidVal = "aa" + checkValidateFunc("field_binary_vector", validVal, invalidVal) + invalidVal = []interface{}{jsonNumber("100")} + checkValidateFunc("field_binary_vector", validVal, invalidVal) + invalidVal = []interface{}{jsonNumber("100"), jsonNumber("101"), jsonNumber("102")} + checkValidateFunc("field_binary_vector", validVal, invalidVal) + invalidVal = []interface{}{100, jsonNumber("100")} + checkValidateFunc("field_binary_vector", validVal, invalidVal) - validVal = []interface{}{float64(1), float64(2), float64(3), float64(4)} - invalidVal = true - checkFunc("field_float_vector", validVal, invalidVal) - invalidVal = []interface{}{float64(1), float64(2), float64(3)} - checkFunc("field_float_vector", validVal, invalidVal) - invalidVal = []interface{}{float64(1), float64(2), float64(3), float64(4), float64(5)} - checkFunc("field_float_vector", validVal, invalidVal) - invalidVal = []interface{}{"a", "b", "c", "d"} - checkFunc("field_float_vector", validVal, invalidVal) - - // error cases - schema = &schemapb.CollectionSchema{ - Name: "schema", - Description: "schema", - AutoID: true, - Fields: make([]*schemapb.FieldSchema, 0), - } - schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ - FieldID: 111, - Name: "field_float_vector", - IsPrimaryKey: false, - DataType: schemapb.DataType_FloatVector, - TypeParams: []*commonpb.KeyValuePair{ - {Key: "dim", Value: "aa"}, - }, + // the float vector dimension is 4, shoud input 4 float values + validVal = []interface{}{jsonNumber("1"), jsonNumber("2"), jsonNumber("3"), jsonNumber("4")} + invalidVal = true + checkValidateFunc("field_float_vector", validVal, invalidVal) + invalidVal = []interface{}{jsonNumber("1"), jsonNumber("2"), jsonNumber("3")} + checkValidateFunc("field_float_vector", validVal, invalidVal) + invalidVal = []interface{}{jsonNumber("1"), jsonNumber("2"), jsonNumber("3"), jsonNumber("4"), jsonNumber("5")} + checkValidateFunc("field_float_vector", validVal, invalidVal) + invalidVal = []interface{}{"a", "b", "c", "d"} + checkValidateFunc("field_float_vector", validVal, invalidVal) }) - validators = make(map[storage.FieldID]*Validator) - err = initValidators(schema, validators) - assert.NotNil(t, err) + t.Run("check convert functions", func(t *testing.T) { + var validVal interface{} = true + var invalidVal interface{} + checkConvertFunc("field_bool", validVal, invalidVal) - schema.Fields = make([]*schemapb.FieldSchema, 0) - schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ - FieldID: 110, - Name: "field_binary_vector", - IsPrimaryKey: false, - DataType: schemapb.DataType_BinaryVector, - TypeParams: []*commonpb.KeyValuePair{ - {Key: "dim", Value: "aa"}, - }, + validVal = jsonNumber("100") + invalidVal = jsonNumber("128") + checkConvertFunc("field_int8", validVal, invalidVal) + invalidVal = jsonNumber("65536") + checkConvertFunc("field_int16", validVal, invalidVal) + invalidVal = jsonNumber("2147483648") + checkConvertFunc("field_int32", validVal, invalidVal) + invalidVal = jsonNumber("1.2") + checkConvertFunc("field_int64", validVal, invalidVal) + invalidVal = jsonNumber("dummy") + checkConvertFunc("field_float", validVal, invalidVal) + checkConvertFunc("field_double", validVal, invalidVal) + + validVal = "aa" + checkConvertFunc("field_string", validVal, nil) + + // the binary vector dimension is 16, shoud input two uint8 values, each value should between 0~255 + validVal = []interface{}{jsonNumber("100"), jsonNumber("101")} + invalidVal = []interface{}{jsonNumber("100"), jsonNumber("256")} + checkConvertFunc("field_binary_vector", validVal, invalidVal) + + // the float vector dimension is 4, each value should be valid float number + validVal = []interface{}{jsonNumber("1"), jsonNumber("2"), jsonNumber("3"), jsonNumber("4")} + invalidVal = []interface{}{jsonNumber("1"), jsonNumber("2"), jsonNumber("3"), jsonNumber("dummy")} + checkConvertFunc("field_float_vector", validVal, invalidVal) }) - err = initValidators(schema, validators) - assert.NotNil(t, err) + t.Run("init error cases", func(t *testing.T) { + schema = &schemapb.CollectionSchema{ + Name: "schema", + Description: "schema", + AutoID: true, + Fields: make([]*schemapb.FieldSchema, 0), + } + schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ + FieldID: 111, + Name: "field_float_vector", + IsPrimaryKey: false, + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "aa"}, + }, + }) - // unsupported data type - schema.Fields = make([]*schemapb.FieldSchema, 0) - schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ - FieldID: 110, - Name: "dummy", - IsPrimaryKey: false, - DataType: schemapb.DataType_None, + validators = make(map[storage.FieldID]*Validator) + err = initValidators(schema, validators) + assert.NotNil(t, err) + + schema.Fields = make([]*schemapb.FieldSchema, 0) + schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ + FieldID: 110, + Name: "field_binary_vector", + IsPrimaryKey: false, + DataType: schemapb.DataType_BinaryVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "aa"}, + }, + }) + + err = initValidators(schema, validators) + assert.NotNil(t, err) + + // unsupported data type + schema.Fields = make([]*schemapb.FieldSchema, 0) + schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ + FieldID: 110, + Name: "dummy", + IsPrimaryKey: false, + DataType: schemapb.DataType_None, + }) + + err = initValidators(schema, validators) + assert.NotNil(t, err) }) - - err = initValidators(schema, validators) - assert.NotNil(t, err) } func Test_GetFileNameAndExt(t *testing.T) { diff --git a/internal/util/importutil/json_handler.go b/internal/util/importutil/json_handler.go index e7977e42f6..e9d3feb368 100644 --- a/internal/util/importutil/json_handler.go +++ b/internal/util/importutil/json_handler.go @@ -17,9 +17,11 @@ package importutil import ( + "encoding/json" "errors" "fmt" "reflect" + "strconv" "go.uber.org/zap" @@ -319,7 +321,15 @@ func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error { pk = rowIDBegin + int64(i) } else { value := row[v.primaryKey] - pk = int64(value.(float64)) + // parse the pk from a string + strValue := string(value.(json.Number)) + pk, err = strconv.ParseInt(strValue, 10, 64) + if err != nil { + log.Error("JSON row consumer: failed to parse primary key at the row", + zap.String("value", strValue), zap.Int64("rowNumber", v.rowCounter+int64(i)), zap.Error(err)) + return fmt.Errorf("failed to parse primary key '%s' at the row %d, error: %w", + strValue, v.rowCounter+int64(i), err) + } } hash, err := typeutil.Hash32Int64(pk) diff --git a/internal/util/importutil/json_handler_test.go b/internal/util/importutil/json_handler_test.go index 9264cb11d4..5d7e51460b 100644 --- a/internal/util/importutil/json_handler_test.go +++ b/internal/util/importutil/json_handler_test.go @@ -253,6 +253,15 @@ func Test_JSONRowConsumer(t *testing.T) { assert.Equal(t, shardNum, callTime) assert.Equal(t, 5, totalCount) + + // parse primary key error + reader = strings.NewReader(`{ + "rows":[ + {"field_bool": true, "field_int8": 10, "field_int16": 101, "field_int32": 1001, "field_int64": 0.5, "field_float": 3.14, "field_double": 1.56, "field_string": "hello world", "field_binary_vector": [254, 0], "field_float_vector": [1.1, 1.2, 1.3, 1.4]} + ] + }`) + err = parser.ParseRows(reader, validator) + assert.Error(t, err) } func Test_JSONRowConsumerFlush(t *testing.T) { diff --git a/internal/util/importutil/json_parser.go b/internal/util/importutil/json_parser.go index b7183c4ef0..7082da2746 100644 --- a/internal/util/importutil/json_parser.go +++ b/internal/util/importutil/json_parser.go @@ -97,6 +97,11 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error { dec := json.NewDecoder(r) + // treat number value as a string instead of a float64. + // by default, json lib treat all number values as float64, but if an int64 value + // has more than 15 digits, the value would be incorrect after converting from float64 + dec.UseNumber() + t, err := dec.Token() if err != nil { log.Error("JSON parser: failed to decode the JSON file", zap.Error(err))