diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go index 272447981a..5a696be743 100644 --- a/pkg/util/typeutil/schema.go +++ b/pkg/util/typeutil/schema.go @@ -1555,22 +1555,34 @@ func CreateSparseFloatRowFromMap(input map[string]interface{}) ([]byte, error) { if ok1 && ok2 { // try format1 - for _, v1 := range jsonIndices { - if num1, suc1 := v1.(int); suc1 { - indices = append(indices, uint32(num1)) - } else { - if num2, suc2 := v1.(float64); suc2 && num2 == float64(int(num2)) { - indices = append(indices, uint32(num2)) + for _, idx := range jsonIndices { + if i1, s1 := idx.(int); s1 { + indices = append(indices, uint32(i1)) + } else if i2, s2 := idx.(float64); s2 && i2 == float64(int(i2)) { + indices = append(indices, uint32(i2)) + } else if i3, s3 := idx.(json.Number); s3 { + if num, err := strconv.ParseUint(i3.String(), 0, 32); err == nil { + indices = append(indices, uint32(num)) } else { - return nil, fmt.Errorf("invalid index type: %v(%s)", v1, reflect.TypeOf(v1)) + return nil, err } + } else { + return nil, fmt.Errorf("invalid indicies type: %v(%s)", idx, reflect.TypeOf(idx)) } } - for _, v2 := range jsonValues { - if num, ok := v2.(float64); ok { - values = append(values, float32(num)) + for _, val := range jsonValues { + if v1, s1 := val.(int); s1 { + values = append(values, float32(v1)) + } else if v2, s2 := val.(float64); s2 { + values = append(values, float32(v2)) + } else if v3, s3 := val.(json.Number); s3 { + if num, err := strconv.ParseFloat(v3.String(), 32); err == nil { + values = append(values, float32(num)) + } else { + return nil, err + } } else { - return nil, fmt.Errorf("invalid value type: %s", reflect.TypeOf(v2)) + return nil, fmt.Errorf("invalid values type: %v(%s)", val, reflect.TypeOf(val)) } } } else if !ok1 && !ok2 { diff --git a/pkg/util/typeutil/schema_test.go b/pkg/util/typeutil/schema_test.go index b1e5ec4b83..f487336b94 100644 --- a/pkg/util/typeutil/schema_test.go +++ b/pkg/util/typeutil/schema_test.go @@ -2133,6 +2133,20 @@ func TestParseJsonSparseFloatRow(t *testing.T) { assert.Equal(t, CreateSparseFloatRow([]uint32{1, 3, 5}, []float32{2.0, 1.0, 3.0}), res) }) + t.Run("valid row 3", func(t *testing.T) { + row := map[string]interface{}{"indices": []interface{}{1, 3, 5}, "values": []interface{}{1, 2, 3}} + res, err := CreateSparseFloatRowFromMap(row) + assert.NoError(t, err) + assert.Equal(t, CreateSparseFloatRow([]uint32{1, 3, 5}, []float32{1.0, 2.0, 3.0}), res) + }) + + t.Run("valid row 3", func(t *testing.T) { + row := map[string]interface{}{"indices": []interface{}{math.MaxInt32 + 1}, "values": []interface{}{1.0}} + res, err := CreateSparseFloatRowFromMap(row) + assert.NoError(t, err) + assert.Equal(t, CreateSparseFloatRow([]uint32{math.MaxInt32 + 1}, []float32{1.0}), res) + }) + t.Run("invalid row 1", func(t *testing.T) { row := map[string]interface{}{"indices": []interface{}{1, 3, 5}, "values": []interface{}{1.0, 2.0}} _, err := CreateSparseFloatRowFromMap(row) @@ -2235,6 +2249,20 @@ func TestParseJsonSparseFloatRowBytes(t *testing.T) { assert.Equal(t, CreateSparseFloatRow([]uint32{1, 3, 5}, []float32{2.0, 1.0, 3.0}), res) }) + t.Run("valid row 3", func(t *testing.T) { + row := []byte(`{"indices":[1, 3, 5], "values":[1, 2, 3]}`) + res, err := CreateSparseFloatRowFromJSON(row) + assert.NoError(t, err) + assert.Equal(t, CreateSparseFloatRow([]uint32{1, 3, 5}, []float32{1.0, 2.0, 3.0}), res) + }) + + t.Run("valid row 3", func(t *testing.T) { + row := []byte(`{"indices":[2147483648], "values":[1.0]}`) + res, err := CreateSparseFloatRowFromJSON(row) + assert.NoError(t, err) + assert.Equal(t, CreateSparseFloatRow([]uint32{math.MaxInt32 + 1}, []float32{1.0}), res) + }) + t.Run("invalid row 1", func(t *testing.T) { row := []byte(`{"indices":[1,3,5],"values":[1.0,2.0,3.0`) _, err := CreateSparseFloatRowFromJSON(row)