fix: more comprehensive check on sparse index and value (#32250)

issue: #29419

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
This commit is contained in:
Buqian Zheng 2024-04-15 16:51:24 +08:00 committed by GitHub
parent 0d849a6c0a
commit c93ae72d92
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 45 additions and 3 deletions

View File

@ -1470,10 +1470,16 @@ func ValidateSparseFloatRows(rows ...[]byte) error {
if idx == math.MaxUint32 {
return errors.New("invalid index in sparse float vector: must be less than 2^32-1")
}
if i > 0 && idx < SparseFloatRowIndexAt(row, i-1) {
return errors.New("unsorted indices in sparse float vector")
if i > 0 && idx <= SparseFloatRowIndexAt(row, i-1) {
return errors.New("unsorted or same indices in sparse float vector")
}
val := SparseFloatRowValueAt(row, i)
if err := VerifyFloat(float64(val)); err != nil {
return err
}
if val < 0 {
return errors.New("negative value in sparse float vector")
}
VerifyFloat(float64(SparseFloatRowValueAt(row, i)))
}
}
return nil

View File

@ -2061,6 +2061,42 @@ func TestValidateSparseFloatRows(t *testing.T) {
assert.Error(t, err)
})
t.Run("same index", func(t *testing.T) {
rows := [][]byte{
testutils.CreateSparseFloatRow([]uint32{100, 100, 500}, []float32{1.0, 2.0, 3.0}),
}
err := ValidateSparseFloatRows(rows...)
assert.Error(t, err)
})
t.Run("negative value", func(t *testing.T) {
rows := [][]byte{
testutils.CreateSparseFloatRow([]uint32{100, 200, 500}, []float32{-1.0, 2.0, 3.0}),
}
err := ValidateSparseFloatRows(rows...)
assert.Error(t, err)
})
t.Run("invalid value", func(t *testing.T) {
rows := [][]byte{
testutils.CreateSparseFloatRow([]uint32{100, 200, 500}, []float32{float32(math.NaN()), 2.0, 3.0}),
}
err := ValidateSparseFloatRows(rows...)
assert.Error(t, err)
rows = [][]byte{
testutils.CreateSparseFloatRow([]uint32{100, 200, 500}, []float32{float32(math.Inf(1)), 2.0, 3.0}),
}
err = ValidateSparseFloatRows(rows...)
assert.Error(t, err)
rows = [][]byte{
testutils.CreateSparseFloatRow([]uint32{100, 200, 500}, []float32{float32(math.Inf(-1)), 2.0, 3.0}),
}
err = ValidateSparseFloatRows(rows...)
assert.Error(t, err)
})
t.Run("invalid index", func(t *testing.T) {
rows := [][]byte{
testutils.CreateSparseFloatRow([]uint32{3, 5, math.MaxUint32}, []float32{1.0, 2.0, 3.0}),