From c93ae72d92c143e7080038bca811d0469b5eb6f0 Mon Sep 17 00:00:00 2001 From: Buqian Zheng Date: Mon, 15 Apr 2024 16:51:24 +0800 Subject: [PATCH] fix: more comprehensive check on sparse index and value (#32250) issue: #29419 Signed-off-by: Buqian Zheng --- pkg/util/typeutil/schema.go | 12 ++++++++--- pkg/util/typeutil/schema_test.go | 36 ++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 3 deletions(-) diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go index c036f13aa5..6d27e79ffe 100644 --- a/pkg/util/typeutil/schema.go +++ b/pkg/util/typeutil/schema.go @@ -1470,10 +1470,16 @@ func ValidateSparseFloatRows(rows ...[]byte) error { if idx == math.MaxUint32 { return errors.New("invalid index in sparse float vector: must be less than 2^32-1") } - if i > 0 && idx < SparseFloatRowIndexAt(row, i-1) { - return errors.New("unsorted indices in sparse float vector") + if i > 0 && idx <= SparseFloatRowIndexAt(row, i-1) { + return errors.New("unsorted or same indices in sparse float vector") + } + val := SparseFloatRowValueAt(row, i) + if err := VerifyFloat(float64(val)); err != nil { + return err + } + if val < 0 { + return errors.New("negative value in sparse float vector") } - VerifyFloat(float64(SparseFloatRowValueAt(row, i))) } } return nil diff --git a/pkg/util/typeutil/schema_test.go b/pkg/util/typeutil/schema_test.go index 093da91d7b..063c90bcbc 100644 --- a/pkg/util/typeutil/schema_test.go +++ b/pkg/util/typeutil/schema_test.go @@ -2061,6 +2061,42 @@ func TestValidateSparseFloatRows(t *testing.T) { assert.Error(t, err) }) + t.Run("same index", func(t *testing.T) { + rows := [][]byte{ + testutils.CreateSparseFloatRow([]uint32{100, 100, 500}, []float32{1.0, 2.0, 3.0}), + } + err := ValidateSparseFloatRows(rows...) + assert.Error(t, err) + }) + + t.Run("negative value", func(t *testing.T) { + rows := [][]byte{ + testutils.CreateSparseFloatRow([]uint32{100, 200, 500}, []float32{-1.0, 2.0, 3.0}), + } + err := ValidateSparseFloatRows(rows...) + assert.Error(t, err) + }) + + t.Run("invalid value", func(t *testing.T) { + rows := [][]byte{ + testutils.CreateSparseFloatRow([]uint32{100, 200, 500}, []float32{float32(math.NaN()), 2.0, 3.0}), + } + err := ValidateSparseFloatRows(rows...) + assert.Error(t, err) + + rows = [][]byte{ + testutils.CreateSparseFloatRow([]uint32{100, 200, 500}, []float32{float32(math.Inf(1)), 2.0, 3.0}), + } + err = ValidateSparseFloatRows(rows...) + assert.Error(t, err) + + rows = [][]byte{ + testutils.CreateSparseFloatRow([]uint32{100, 200, 500}, []float32{float32(math.Inf(-1)), 2.0, 3.0}), + } + err = ValidateSparseFloatRows(rows...) + assert.Error(t, err) + }) + t.Run("invalid index", func(t *testing.T) { rows := [][]byte{ testutils.CreateSparseFloatRow([]uint32{3, 5, math.MaxUint32}, []float32{1.0, 2.0, 3.0}),