diff --git a/internal/storage/field_stats.go b/internal/storage/field_stats.go
new file mode 100644
index 0000000000..a26e8aa9e1
--- /dev/null
+++ b/internal/storage/field_stats.go
@@ -0,0 +1,460 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package storage
+
+import (
+	"encoding/json"
+	"fmt"
+
+	"github.com/bits-and-blooms/bloom/v3"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	"github.com/milvus-io/milvus/pkg/common"
+	"github.com/milvus-io/milvus/pkg/util/merr"
+	"github.com/milvus-io/milvus/pkg/util/paramtable"
+)
+
+// FieldStats holds per-field statistics: min/max and a bloom filter for
+// scalar fields, centroids for vector fields.
+// todo: compatible to PrimaryKeyStats
+type FieldStats struct {
+	FieldID   int64              `json:"fieldID"`
+	Type      schemapb.DataType  `json:"type"`
+	Max       ScalarFieldValue   `json:"max"`       // for scalar field
+	Min       ScalarFieldValue   `json:"min"`       // for scalar field
+	BF        *bloom.BloomFilter `json:"bf"`        // for scalar field
+	Centroids []VectorFieldValue `json:"centroids"` // for vector field
+}
+
+// UnmarshalJSON decodes FieldStats from JSON. It also accepts the legacy
+// PrimaryKeyStats keys ("pkType", "maxPk", "minPk"); a missing or non-positive
+// type defaults to Int64. Max/Min/BF are decoded for the scalar types listed
+// below; every other type only has its "centroids" entry decoded, if present.
+func (stats *FieldStats) UnmarshalJSON(data []byte) error {
+	var messageMap map[string]*json.RawMessage
+	err := json.Unmarshal(data, &messageMap)
+	if err != nil {
+		return err
+	}
+
+	if value, ok := messageMap["fieldID"]; ok && value != nil {
+		err = json.Unmarshal(*value, &stats.FieldID)
+		if err != nil {
+			return err
+		}
+	} else {
+		return fmt.Errorf("invalid fieldStats, no fieldID")
+	}
+
+	stats.Type = schemapb.DataType_Int64
+	value, ok := messageMap["type"]
+	if !ok {
+		value, ok = messageMap["pkType"]
+	}
+	if ok && value != nil {
+		var typeValue int32
+		err = json.Unmarshal(*value, &typeValue)
+		if err != nil {
+			return err
+		}
+		if typeValue > 0 {
+			stats.Type = schemapb.DataType(typeValue)
+		}
+	}
+
+	isScalarField := false
+	switch stats.Type {
+	case schemapb.DataType_Int8:
+		stats.Max = &Int8FieldValue{}
+		stats.Min = &Int8FieldValue{}
+		isScalarField = true
+	case schemapb.DataType_Int16:
+		stats.Max = &Int16FieldValue{}
+		stats.Min = &Int16FieldValue{}
+		isScalarField = true
+	case schemapb.DataType_Int32:
+		stats.Max = &Int32FieldValue{}
+		stats.Min = &Int32FieldValue{}
+		isScalarField = true
+	case schemapb.DataType_Int64:
+		stats.Max = &Int64FieldValue{}
+		stats.Min = &Int64FieldValue{}
+		isScalarField = true
+	case schemapb.DataType_Float:
+		stats.Max = &FloatFieldValue{}
+		stats.Min = &FloatFieldValue{}
+		isScalarField = true
+	case schemapb.DataType_Double:
+		stats.Max = &DoubleFieldValue{}
+		stats.Min = &DoubleFieldValue{}
+		isScalarField = true
+	case schemapb.DataType_String:
+		stats.Max = &StringFieldValue{}
+		stats.Min = &StringFieldValue{}
+		isScalarField = true
+	case schemapb.DataType_VarChar:
+		stats.Max = &VarCharFieldValue{}
+		stats.Min = &VarCharFieldValue{}
+		isScalarField = true
+	case schemapb.DataType_FloatVector:
+		stats.Centroids = []VectorFieldValue{}
+		isScalarField = false
+	default:
+		// unsupported data type
+	}
+
+	if isScalarField {
+		if value, ok := messageMap["max"]; ok && value != nil {
+			err = json.Unmarshal(*value, &stats.Max)
+			if err != nil {
+				return err
+			}
+		}
+		if value, ok := messageMap["min"]; ok && value != nil {
+			err = json.Unmarshal(*value, &stats.Min)
+			if err != nil {
+				return err
+			}
+		}
+		// compatible with primaryKeyStats
+		if maxPkMessage, ok := messageMap["maxPk"]; ok && maxPkMessage != nil {
+			err = json.Unmarshal(*maxPkMessage, stats.Max)
+			if err != nil {
+				return err
+			}
+		}
+
+		if minPkMessage, ok := messageMap["minPk"]; ok && minPkMessage != nil {
+			err = json.Unmarshal(*minPkMessage, stats.Min)
+			if err != nil {
+				return err
+			}
+		}
+
+		if bfMessage, ok := messageMap["bf"]; ok && bfMessage != nil {
+			stats.BF = &bloom.BloomFilter{}
+			err = stats.BF.UnmarshalJSON(*bfMessage)
+			if err != nil {
+				return err
+			}
+		}
+	} else if centroidsMessage, ok := messageMap["centroids"]; ok && centroidsMessage != nil {
+		// a missing or null "centroids" entry is valid (e.g. stats marshaled
+		// before SetVectorCentroids); dereferencing it blindly would panic
+		stats.initCentroids(data, stats.Type)
+		err = json.Unmarshal(*centroidsMessage, &stats.Centroids)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// initCentroids appends one empty, concretely typed value to stats.Centroids
+// per element of the "centroids" array in data, giving the caller's
+// json.Unmarshal non-nil VectorFieldValue slots to decode into. A decode
+// failure is not reported here; the caller's own unmarshal surfaces it.
+func (stats *FieldStats) initCentroids(data []byte, dataType schemapb.DataType) {
+	// only the centroid count is needed, so the other keys are not decoded
+	var aux struct {
+		Centroids []json.RawMessage `json:"centroids"`
+	}
+	if err := json.Unmarshal(data, &aux); err != nil {
+		return
+	}
+	for range aux.Centroids {
+		switch dataType {
+		case schemapb.DataType_FloatVector:
+			stats.Centroids = append(stats.Centroids, &FloatVectorFieldValue{})
+		default:
+			// other vector datatype
+		}
+	}
+}
+
+// UpdateByMsgs folds every value in msgs into Min/Max and the bloom filter.
+// msgs must hold the concrete FieldData type matching stats.Type (the type
+// assertion panics otherwise) and stats.BF must be non-nil. Empty input and
+// unsupported types leave stats untouched.
+func (stats *FieldStats) UpdateByMsgs(msgs FieldData) {
+	switch stats.Type {
+	case schemapb.DataType_Int8:
+		data := msgs.(*Int8FieldData).Data
+		// nothing to update for an empty batch
+		if len(data) < 1 {
+			return
+		}
+		b := make([]byte, 8)
+		for _, int8Value := range data {
+			pk := NewInt8FieldValue(int8Value)
+			stats.UpdateMinMax(pk)
+			common.Endian.PutUint64(b, uint64(int8Value))
+			stats.BF.Add(b)
+		}
+	case schemapb.DataType_Int16:
+		data := msgs.(*Int16FieldData).Data
+		// nothing to update for an empty batch
+		if len(data) < 1 {
+			return
+		}
+		b := make([]byte, 8)
+		for _, int16Value := range data {
+			pk := NewInt16FieldValue(int16Value)
+			stats.UpdateMinMax(pk)
+			common.Endian.PutUint64(b, uint64(int16Value))
+			stats.BF.Add(b)
+		}
+	case schemapb.DataType_Int32:
+		data := msgs.(*Int32FieldData).Data
+		// nothing to update for an empty batch
+		if len(data) < 1 {
+			return
+		}
+		b := make([]byte, 8)
+		for _, int32Value := range data {
+			pk := NewInt32FieldValue(int32Value)
+			stats.UpdateMinMax(pk)
+			common.Endian.PutUint64(b, uint64(int32Value))
+			stats.BF.Add(b)
+		}
+	case schemapb.DataType_Int64:
+		data := msgs.(*Int64FieldData).Data
+		// nothing to update for an empty batch
+		if len(data) < 1 {
+			return
+		}
+		b := make([]byte, 8)
+		for _, int64Value := range data {
+			pk := NewInt64FieldValue(int64Value)
+			stats.UpdateMinMax(pk)
+			common.Endian.PutUint64(b, uint64(int64Value))
+			stats.BF.Add(b)
+		}
+	case schemapb.DataType_Float:
+		data := msgs.(*FloatFieldData).Data
+		// nothing to update for an empty batch
+		if len(data) < 1 {
+			return
+		}
+		b := make([]byte, 8)
+		for _, floatValue := range data {
+			pk := NewFloatFieldValue(floatValue)
+			stats.UpdateMinMax(pk)
+			// NOTE(review): uint64(float) drops the fraction and is
+			// implementation-dependent for negative/out-of-range values (also
+			// in the Double case); kept because the tests hash keys this way
+			common.Endian.PutUint64(b, uint64(floatValue))
+			stats.BF.Add(b)
+		}
+	case schemapb.DataType_Double:
+		data := msgs.(*DoubleFieldData).Data
+		// nothing to update for an empty batch
+		if len(data) < 1 {
+			return
+		}
+		b := make([]byte, 8)
+		for _, doubleValue := range data {
+			pk := NewDoubleFieldValue(doubleValue)
+			stats.UpdateMinMax(pk)
+			common.Endian.PutUint64(b, uint64(doubleValue))
+			stats.BF.Add(b)
+		}
+	case schemapb.DataType_String:
+		data := msgs.(*StringFieldData).Data
+		// nothing to update for an empty batch
+		if len(data) < 1 {
+			return
+		}
+		for _, str := range data {
+			pk := NewStringFieldValue(str)
+			stats.UpdateMinMax(pk)
+			stats.BF.AddString(str)
+		}
+	case schemapb.DataType_VarChar:
+		data := msgs.(*StringFieldData).Data
+		// nothing to update for an empty batch
+		if len(data) < 1 {
+			return
+		}
+		for _, str := range data {
+			pk := NewVarCharFieldValue(str)
+			stats.UpdateMinMax(pk)
+			stats.BF.AddString(str)
+		}
+	default:
+		// TODO::
+	}
+}
+
+// Update folds a single value into Min/Max and the bloom filter. pk must wrap
+// the Go type matching stats.Type (the type assertion panics otherwise) and
+// stats.BF must be non-nil.
+func (stats *FieldStats) Update(pk ScalarFieldValue) {
+	stats.UpdateMinMax(pk)
+	switch stats.Type {
+	case schemapb.DataType_Int8:
+		data := pk.GetValue().(int8)
+		b := make([]byte, 8)
+		common.Endian.PutUint64(b, uint64(data))
+		stats.BF.Add(b)
+	case schemapb.DataType_Int16:
+		data := pk.GetValue().(int16)
+		b := make([]byte, 8)
+		common.Endian.PutUint64(b, uint64(data))
+		stats.BF.Add(b)
+	case schemapb.DataType_Int32:
+		data := pk.GetValue().(int32)
+		b := make([]byte, 8)
+		common.Endian.PutUint64(b, uint64(data))
+		stats.BF.Add(b)
+	case schemapb.DataType_Int64:
+		data := pk.GetValue().(int64)
+		b := make([]byte, 8)
+		common.Endian.PutUint64(b, uint64(data))
+		stats.BF.Add(b)
+	case schemapb.DataType_Float:
+		// NOTE(review): same lossy float-to-uint64 key as UpdateByMsgs (also Double)
+		data := pk.GetValue().(float32)
+		b := make([]byte, 8)
+		common.Endian.PutUint64(b, uint64(data))
+		stats.BF.Add(b)
+	case schemapb.DataType_Double:
+		data := pk.GetValue().(float64)
+		b := make([]byte, 8)
+		common.Endian.PutUint64(b, uint64(data))
+		stats.BF.Add(b)
+	case schemapb.DataType_String:
+		data := pk.GetValue().(string)
+		stats.BF.AddString(data)
+	case schemapb.DataType_VarChar:
+		data := pk.GetValue().(string)
+		stats.BF.AddString(data)
+	default:
+		// todo support vector field
+	}
+}
+
+// UpdateMinMax widens the [Min, Max] range so that it includes pk
+func (stats *FieldStats) UpdateMinMax(pk ScalarFieldValue) {
+	if stats.Min == nil || stats.Min.GT(pk) {
+		stats.Min = pk
+	}
+
+	if stats.Max == nil || stats.Max.LT(pk) {
+		stats.Max = pk
+	}
+}
+
+// SetVectorCentroids update centroids value
+func (stats *FieldStats) SetVectorCentroids(centroids ...VectorFieldValue) {
+	stats.Centroids = centroids
+}
+
+// NewFieldStats creates empty stats for a field. FloatVector stats carry no
+// bloom filter; every other type gets one sized for rowNum entries, so a
+// negative rowNum is rejected instead of wrapping around in the uint cast.
+func NewFieldStats(fieldID int64, pkType schemapb.DataType, rowNum int64) (*FieldStats, error) {
+	if pkType == schemapb.DataType_FloatVector {
+		return &FieldStats{
+			FieldID: fieldID,
+			Type:    pkType,
+		}, nil
+	}
+	if rowNum < 0 {
+		return nil, fmt.Errorf("invalid rowNum %d for stats of field %d", rowNum, fieldID)
+	}
+	return &FieldStats{
+		FieldID: fieldID,
+		Type:    pkType,
+		BF:      bloom.NewWithEstimates(uint(rowNum), paramtable.Get().CommonCfg.MaxBloomFalsePositive.GetAsFloat()),
+	}, nil
+}
+
+// FieldStatsWriter writes stats to buffer
+type FieldStatsWriter struct {
+	buffer []byte
+}
+
+// GetBuffer returns buffer
+func (sw *FieldStatsWriter) GetBuffer() []byte {
+	return sw.buffer
+}
+
+// GenerateList writes Stats slice to buffer
+func (sw *FieldStatsWriter) GenerateList(stats []*FieldStats) error {
+	b, err := json.Marshal(stats)
+	if err != nil {
+		return err
+	}
+	sw.buffer = b
+	return nil
+}
+
+// GenerateByData writes data from @msgs with @fieldID to @buffer
+func (sw *FieldStatsWriter) GenerateByData(fieldID int64, pkType schemapb.DataType, msgs ...FieldData) error {
+	statsList := make([]*FieldStats, 0, len(msgs))
+	for _, msg := range msgs {
+		stats := &FieldStats{
+			FieldID: fieldID,
+			Type:    pkType,
+			BF:      bloom.NewWithEstimates(uint(msg.RowNum()), paramtable.Get().CommonCfg.MaxBloomFalsePositive.GetAsFloat()),
+		}
+
+		stats.UpdateByMsgs(msg)
+		statsList = append(statsList, stats)
+	}
+	return sw.GenerateList(statsList)
+}
+
+// FieldStatsReader reads stats
+type FieldStatsReader struct {
+	buffer []byte
+}
+
+// SetBuffer sets buffer
+func (sr *FieldStatsReader) SetBuffer(buffer []byte) {
+	sr.buffer = buffer
+}
+
+// GetFieldStatsList decodes the buffer as a list of FieldStats. For
+// compatibility with PrimaryKeyStats it falls back to decoding a single
+// object, which is returned as a one-element list.
+func (sr *FieldStatsReader) GetFieldStatsList() ([]*FieldStats, error) {
+	var statsList []*FieldStats
+	err := json.Unmarshal(sr.buffer, &statsList)
+	if err != nil {
+		// Compatible to PrimaryKey Stats
+		stats := &FieldStats{}
+		errNew := json.Unmarshal(sr.buffer, &stats)
+		if errNew != nil {
+			return nil, merr.WrapErrParameterInvalid("valid JSON", string(sr.buffer), err.Error(), errNew.Error())
+		}
+		return []*FieldStats{stats}, nil
+	}
+
+	return statsList, nil
+}
+
+// DeserializeFieldStats decodes the stats stored in blob. A nil or empty blob
+// yields an empty, non-nil list.
+func DeserializeFieldStats(blob *Blob) ([]*FieldStats, error) {
+	if blob == nil || len(blob.Value) == 0 {
+		return []*FieldStats{}, nil
+	}
+	sr := &FieldStatsReader{buffer: blob.Value}
+	return sr.GetFieldStatsList()
+}
diff --git a/internal/storage/field_stats_test.go b/internal/storage/field_stats_test.go
new file mode 100644
index 0000000000..e169902bf9
--- /dev/null
+++ b/internal/storage/field_stats_test.go
@@ -0,0 +1,709 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package storage
+
+import (
+	"encoding/json"
+	"testing"
+
+	"github.com/bits-and-blooms/bloom/v3"
+	"github.com/stretchr/testify/assert"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	"github.com/milvus-io/milvus/pkg/common"
+	"github.com/milvus-io/milvus/pkg/util/merr"
+)
+
+func TestFieldStatsUpdate(t *testing.T) {
+	fieldStat1, err := NewFieldStats(1, schemapb.DataType_Int8, 2)
+	assert.NoError(t, err)
+	fieldStat1.Update(NewInt8FieldValue(1))
+	fieldStat1.Update(NewInt8FieldValue(3))
+	assert.Equal(t, int8(3), fieldStat1.Max.GetValue())
+	assert.Equal(t, int8(1), fieldStat1.Min.GetValue())
+
+	fieldStat2, err := NewFieldStats(1, schemapb.DataType_Int16, 2)
+	assert.NoError(t, err)
+	fieldStat2.Update(NewInt16FieldValue(99))
+	fieldStat2.Update(NewInt16FieldValue(201))
+	assert.Equal(t, int16(201), fieldStat2.Max.GetValue())
+	assert.Equal(t, int16(99), fieldStat2.Min.GetValue())
+
+	fieldStat3, err := NewFieldStats(1, schemapb.DataType_Int32, 2)
+	assert.NoError(t, err)
+	fieldStat3.Update(NewInt32FieldValue(99))
+	fieldStat3.Update(NewInt32FieldValue(201))
+	assert.Equal(t, int32(201), fieldStat3.Max.GetValue())
+	assert.Equal(t, int32(99), fieldStat3.Min.GetValue())
+
+	fieldStat4, err := NewFieldStats(1, schemapb.DataType_Int64, 2)
+	assert.NoError(t, err)
+	fieldStat4.Update(NewInt64FieldValue(99))
+	fieldStat4.Update(NewInt64FieldValue(201))
+	assert.Equal(t, int64(201), fieldStat4.Max.GetValue())
+	assert.Equal(t, int64(99), fieldStat4.Min.GetValue())
+
+	fieldStat5, err := NewFieldStats(1, schemapb.DataType_Float, 2)
+	assert.NoError(t, err)
+	fieldStat5.Update(NewFloatFieldValue(99.0))
+	fieldStat5.Update(NewFloatFieldValue(201.0))
+	assert.Equal(t, float32(201.0), fieldStat5.Max.GetValue())
+	assert.Equal(t, float32(99.0), fieldStat5.Min.GetValue())
+
+	fieldStat6, err := NewFieldStats(1, schemapb.DataType_Double, 2)
+	assert.NoError(t, err)
+	fieldStat6.Update(NewDoubleFieldValue(9.9))
+	fieldStat6.Update(NewDoubleFieldValue(20.1))
+	assert.Equal(t, float64(20.1), fieldStat6.Max.GetValue())
+	assert.Equal(t, float64(9.9), fieldStat6.Min.GetValue())
+
+	fieldStat7, err := NewFieldStats(2, schemapb.DataType_String, 2)
+	assert.NoError(t, err)
+	fieldStat7.Update(NewStringFieldValue("a"))
+	fieldStat7.Update(NewStringFieldValue("z"))
+	assert.Equal(t, "z", fieldStat7.Max.GetValue())
+	assert.Equal(t, "a", fieldStat7.Min.GetValue())
+
+	fieldStat8, err := NewFieldStats(2, schemapb.DataType_VarChar, 2)
+	assert.NoError(t, err)
+	fieldStat8.Update(NewVarCharFieldValue("a"))
+	fieldStat8.Update(NewVarCharFieldValue("z"))
+	assert.Equal(t, "z", fieldStat8.Max.GetValue())
+	assert.Equal(t, "a", fieldStat8.Min.GetValue())
+}
+
+func TestFieldStatsWriter_Int8FieldValue(t *testing.T) {
+	data := &Int8FieldData{
+		Data: []int8{1, 2, 3, 4, 5, 6, 7, 8, 9},
+	}
+	sw := &FieldStatsWriter{}
+	err := sw.GenerateByData(common.RowIDField, schemapb.DataType_Int8, data)
+	assert.NoError(t, err)
+	b := sw.GetBuffer()
+
+	sr := &FieldStatsReader{}
+	sr.SetBuffer(b)
+	statsList, err := sr.GetFieldStatsList()
+	assert.NoError(t, err)
+	stats := statsList[0]
+	maxPk := NewInt8FieldValue(9)
+	minPk := NewInt8FieldValue(1)
+	assert.True(t, stats.Max.EQ(maxPk))
+	assert.True(t, stats.Min.EQ(minPk))
+	buffer := make([]byte, 8)
+	for _, id := range data.Data {
+		common.Endian.PutUint64(buffer, uint64(id))
+		assert.True(t, stats.BF.Test(buffer))
+	}
+
+	msgs := &Int8FieldData{
+		Data: []int8{},
+	}
+	err = sw.GenerateByData(common.RowIDField, schemapb.DataType_Int8, msgs)
+	assert.NoError(t, err)
+}
+
+func TestFieldStatsWriter_Int16FieldValue(t *testing.T) {
+	data := &Int16FieldData{
+		Data: []int16{1, 2, 3, 4, 5, 6, 7, 8, 9},
+	}
+	sw := &FieldStatsWriter{}
+	err := sw.GenerateByData(common.RowIDField, schemapb.DataType_Int16, data)
+	assert.NoError(t, err)
+	b := sw.GetBuffer()
+
+	sr := &FieldStatsReader{}
+	sr.SetBuffer(b)
+	statsList, err := sr.GetFieldStatsList()
+	assert.NoError(t, err)
+	stats := statsList[0]
+	maxPk := NewInt16FieldValue(9)
+	minPk := NewInt16FieldValue(1)
+	assert.True(t, stats.Max.EQ(maxPk))
+	assert.True(t, stats.Min.EQ(minPk))
+	buffer := make([]byte, 8)
+	for _, id := range data.Data {
+		common.Endian.PutUint64(buffer, uint64(id))
+		assert.True(t, stats.BF.Test(buffer))
+	}
+
+	msgs := &Int16FieldData{
+		Data: []int16{},
+	}
+	err = sw.GenerateByData(common.RowIDField, schemapb.DataType_Int16, msgs)
+	assert.NoError(t, err)
+}
+
+func TestFieldStatsWriter_Int32FieldValue(t *testing.T) {
+	data := &Int32FieldData{
+		Data: []int32{1, 2, 3, 4, 5, 6, 7, 8, 9},
+	}
+	sw := &FieldStatsWriter{}
+	err := sw.GenerateByData(common.RowIDField, schemapb.DataType_Int32, data)
+	assert.NoError(t, err)
+	b := sw.GetBuffer()
+
+	sr := &FieldStatsReader{}
+	sr.SetBuffer(b)
+	statsList, err := sr.GetFieldStatsList()
+	assert.NoError(t, err)
+	stats := statsList[0]
+	maxPk := NewInt32FieldValue(9)
+	minPk := NewInt32FieldValue(1)
+	assert.True(t, stats.Max.EQ(maxPk))
+	assert.True(t, stats.Min.EQ(minPk))
+	buffer := make([]byte, 8)
+	for _, id := range data.Data {
+		common.Endian.PutUint64(buffer, uint64(id))
+		assert.True(t, stats.BF.Test(buffer))
+	}
+
+	msgs := &Int32FieldData{
+		Data: []int32{},
+	}
+	err = sw.GenerateByData(common.RowIDField, schemapb.DataType_Int32, msgs)
+	assert.NoError(t, err)
+}
+
+func TestFieldStatsWriter_Int64FieldValue(t *testing.T) {
+	data := &Int64FieldData{
+		Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9},
+	}
+	sw := &FieldStatsWriter{}
+	err := sw.GenerateByData(common.RowIDField, schemapb.DataType_Int64, data)
+	assert.NoError(t, err)
+	b := sw.GetBuffer()
+
+	sr := &FieldStatsReader{}
+	sr.SetBuffer(b)
+	statsList, err := sr.GetFieldStatsList()
+	assert.NoError(t, err)
+	stats := statsList[0]
+	maxPk := NewInt64FieldValue(9)
+	minPk := NewInt64FieldValue(1)
+	assert.True(t, stats.Max.EQ(maxPk))
+	assert.True(t, stats.Min.EQ(minPk))
+	buffer := make([]byte, 8)
+	for _, id := range data.Data {
+		common.Endian.PutUint64(buffer, uint64(id))
+		assert.True(t, stats.BF.Test(buffer))
+	}
+
+	msgs := &Int64FieldData{
+		Data: []int64{},
+	}
+	err = sw.GenerateByData(common.RowIDField, schemapb.DataType_Int64, msgs)
+	assert.NoError(t, err)
+}
+
+func TestFieldStatsWriter_FloatFieldValue(t *testing.T) {
+	data := &FloatFieldData{
+		Data: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
+	}
+	sw := &FieldStatsWriter{}
+	err := sw.GenerateByData(common.RowIDField, schemapb.DataType_Float, data)
+	assert.NoError(t, err)
+	b := sw.GetBuffer()
+
+	sr := &FieldStatsReader{}
+	sr.SetBuffer(b)
+	statsList, err := sr.GetFieldStatsList()
+	assert.NoError(t, err)
+	stats := statsList[0]
+	maxPk := NewFloatFieldValue(9)
+	minPk := NewFloatFieldValue(1)
+	assert.True(t, stats.Max.EQ(maxPk))
+	assert.True(t, stats.Min.EQ(minPk))
+	buffer := make([]byte, 8)
+	for _, id := range data.Data {
+		common.Endian.PutUint64(buffer, uint64(id))
+		assert.True(t, stats.BF.Test(buffer))
+	}
+
+	msgs := &FloatFieldData{
+		Data: []float32{},
+	}
+	err = sw.GenerateByData(common.RowIDField, schemapb.DataType_Float, msgs)
+	assert.NoError(t, err)
+}
+
+func TestFieldStatsWriter_DoubleFieldValue(t *testing.T) {
+	data := &DoubleFieldData{
+		Data: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9},
+	}
+	sw := &FieldStatsWriter{}
+	err := sw.GenerateByData(common.RowIDField, schemapb.DataType_Double, data)
+	assert.NoError(t, err)
+	b := sw.GetBuffer()
+
+	sr := &FieldStatsReader{}
+	sr.SetBuffer(b)
+	statsList, err := sr.GetFieldStatsList()
+	assert.NoError(t, err)
+	stats := statsList[0]
+	maxPk := NewDoubleFieldValue(9)
+	minPk := NewDoubleFieldValue(1)
+	assert.True(t, stats.Max.EQ(maxPk))
+	assert.True(t, stats.Min.EQ(minPk))
+	buffer := make([]byte, 8)
+	for _, id := range data.Data {
+		common.Endian.PutUint64(buffer, uint64(id))
+		assert.True(t, stats.BF.Test(buffer))
+	}
+
+	msgs := &DoubleFieldData{
+		Data: []float64{},
+	}
+	err = sw.GenerateByData(common.RowIDField, schemapb.DataType_Double, msgs)
+	assert.NoError(t, err)
+}
+
+func TestFieldStatsWriter_StringFieldValue(t *testing.T) {
+	data := &StringFieldData{
+		Data: []string{"bc", "ac", "abd", "cd", "milvus"},
+	}
+	sw := &FieldStatsWriter{}
+	err := sw.GenerateByData(common.RowIDField, schemapb.DataType_String, data)
+	assert.NoError(t, err)
+	b := sw.GetBuffer()
+	t.Log(string(b))
+
+	sr := &FieldStatsReader{}
+	sr.SetBuffer(b)
+	statsList, err := sr.GetFieldStatsList()
+	assert.NoError(t, err)
+	stats := statsList[0]
+	maxPk := NewStringFieldValue("milvus")
+	minPk := NewStringFieldValue("abd")
+	assert.True(t, stats.Max.EQ(maxPk))
+	assert.True(t, stats.Min.EQ(minPk))
+	for _, id := range data.Data {
+		assert.True(t, stats.BF.TestString(id))
+	}
+
+	msgs := &StringFieldData{
+		Data: []string{},
+	}
+	err = sw.GenerateByData(common.RowIDField, schemapb.DataType_String, msgs)
+	assert.NoError(t, err)
+}
+
+func TestFieldStatsWriter_VarCharFieldValue(t *testing.T) {
+	data := &StringFieldData{
+		Data: []string{"bc", "ac", "abd", "cd", "milvus"},
+	}
+	sw := &FieldStatsWriter{}
+	err := sw.GenerateByData(common.RowIDField, schemapb.DataType_VarChar, data)
+	assert.NoError(t, err)
+	b := sw.GetBuffer()
+	t.Log(string(b))
+
+	sr := &FieldStatsReader{}
+	sr.SetBuffer(b)
+	statsList, err := sr.GetFieldStatsList()
+	assert.NoError(t, err)
+	stats := statsList[0]
+	maxPk := NewVarCharFieldValue("milvus")
+	minPk := NewVarCharFieldValue("abd")
+	assert.True(t, stats.Max.EQ(maxPk))
+	assert.True(t, stats.Min.EQ(minPk))
+	for _, id := range data.Data {
+		assert.True(t, stats.BF.TestString(id))
+	}
+
+	msgs := &StringFieldData{
+		Data: []string{},
+	}
+	err = sw.GenerateByData(common.RowIDField, schemapb.DataType_VarChar, msgs)
+	assert.NoError(t, err)
+}
+
+func TestFieldStatsWriter_BF(t *testing.T) {
+	value := make([]int64, 1000000)
+	for i := 0; i < 1000000; i++ {
+		value[i] = int64(i)
+	}
+	data := &Int64FieldData{
+		Data: value,
+	}
+	t.Log(data.RowNum())
+	sw := &FieldStatsWriter{}
+	err := sw.GenerateByData(common.RowIDField, schemapb.DataType_Int64, data)
+	assert.NoError(t, err)
+
+	sr := &FieldStatsReader{}
+	sr.SetBuffer(sw.GetBuffer())
+	statsList, err := sr.GetFieldStatsList()
+	assert.NoError(t, err)
+	stats := statsList[0]
+	buf := make([]byte, 8)
+
+	for i := 0; i < 1000000; i++ {
+		common.Endian.PutUint64(buf, uint64(i))
+		assert.True(t, stats.BF.Test(buf))
+	}
+
+	// 1000001 was never added; a bloom filter may still report it, so this
+	// probe is an expected miss rather than a guaranteed one
+	common.Endian.PutUint64(buf, uint64(1000001))
+	assert.False(t, stats.BF.Test(buf))
+
+	assert.True(t, stats.Min.EQ(NewInt64FieldValue(0)))
+	assert.True(t, stats.Max.EQ(NewInt64FieldValue(999999)))
+}
+
+// TestFieldStatsWriter_UpgradePrimaryKey checks that JSON marshaled from a
+// PrimaryKeyStats value can be read back through FieldStatsReader.
+func TestFieldStatsWriter_UpgradePrimaryKey(t *testing.T) {
+	data := &Int64FieldData{
+		Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9},
+	}
+
+	stats := &PrimaryKeyStats{
+		FieldID: common.RowIDField,
+		Min:     1,
+		Max:     9,
+		BF:      bloom.NewWithEstimates(100000, 0.05),
+	}
+
+	b := make([]byte, 8)
+	for _, int64Value := range data.Data {
+		common.Endian.PutUint64(b, uint64(int64Value))
+		stats.BF.Add(b)
+	}
+	blob, err := json.Marshal(stats)
+	assert.NoError(t, err)
+	sr := &FieldStatsReader{}
+	sr.SetBuffer(blob)
+	unmarshalledStats, err := sr.GetFieldStatsList()
+	assert.NoError(t, err)
+	maxPk := &Int64FieldValue{
+		Value: 9,
+	}
+	minPk := &Int64FieldValue{
+		Value: 1,
+	}
+	assert.True(t, unmarshalledStats[0].Max.EQ(maxPk))
+	assert.True(t, unmarshalledStats[0].Min.EQ(minPk))
+	buffer := make([]byte, 8)
+	for _, id := range data.Data {
+		common.Endian.PutUint64(buffer, uint64(id))
+		assert.True(t, unmarshalledStats[0].BF.Test(buffer))
+	}
+}
+
+func TestDeserializeFieldStatsFailed(t *testing.T) {
+	t.Run("empty field stats", func(t *testing.T) {
+		blob := &Blob{
+			Value: []byte{},
+		}
+
+		_, err := DeserializeFieldStats(blob)
+		assert.NoError(t, err)
+	})
+
+	t.Run("invalid field stats", func(t *testing.T) {
+		blob := &Blob{
+			Value: []byte("abc"),
+		}
+
+		_, err := DeserializeFieldStats(blob)
+		assert.ErrorIs(t, err, merr.ErrParameterInvalid)
+	})
+
+	t.Run("valid field stats", func(t *testing.T) {
+		blob := &Blob{
+			Value: []byte("[{\"fieldID\":1,\"max\":10, \"min\":1}]"),
+		}
+		_, err := DeserializeFieldStats(blob)
+		assert.NoError(t, err)
+	})
+}
+
+func TestDeserializeFieldStats(t *testing.T) {
+	t.Run("empty field stats", func(t *testing.T) {
+		blob := &Blob{
+			Value: []byte{},
+		}
+
+		_, err := DeserializeFieldStats(blob)
+		assert.NoError(t, err)
+	})
+
+	t.Run("invalid field stats, not valid json", func(t *testing.T) {
+		blob := &Blob{
+			Value: []byte("abc"),
+		}
+		_, err := DeserializeFieldStats(blob)
+		assert.Error(t, err)
+	})
+
+	t.Run("invalid field stats, no fieldID", func(t *testing.T) {
+		blob := &Blob{
+			Value: []byte("{\"field\":\"a\"}"),
+		}
+		_, err := DeserializeFieldStats(blob)
+		assert.Error(t, err)
+	})
+
+	t.Run("invalid field stats, invalid fieldID", func(t *testing.T) {
+		blob := &Blob{
+			Value: []byte("{\"fieldID\":\"a\"}"),
+		}
+		_, err := DeserializeFieldStats(blob)
+		assert.Error(t, err)
+	})
+
+	t.Run("invalid field stats, invalid type", func(t *testing.T) {
+		blob := &Blob{
+			Value: []byte("{\"fieldID\":1,\"type\":\"a\"}"),
+		}
+		_, err := DeserializeFieldStats(blob)
+		assert.Error(t, err)
+	})
+
+	t.Run("invalid field stats, invalid max int64", func(t *testing.T) {
+		blob := &Blob{
+			Value: []byte("{\"fieldID\":1,\"max\":\"a\"}"),
+		}
+		_, err := DeserializeFieldStats(blob)
+		assert.Error(t, err)
+	})
+
+	t.Run("invalid field stats, invalid min int64", func(t *testing.T) {
+		blob := &Blob{
+			Value: []byte("{\"fieldID\":1,\"min\":\"a\"}"),
+		}
+		_, err := DeserializeFieldStats(blob)
+		assert.Error(t, err)
+	})
+
+	t.Run("invalid field stats, invalid max varchar", func(t *testing.T) {
+		blob := &Blob{
+			Value: []byte("{\"fieldID\":1,\"type\":21,\"max\":2}"),
+		}
+		_, err := DeserializeFieldStats(blob)
+		assert.Error(t, err)
+	})
+
+	t.Run("invalid field stats, invalid min varchar", func(t *testing.T) {
+		blob := &Blob{
+			Value: []byte("{\"fieldID\":1,\"type\":21,\"min\":1}"),
+		}
+		_, err := DeserializeFieldStats(blob)
+		assert.Error(t, err)
+	})
+
+	t.Run("valid int64 field stats", func(t *testing.T) {
+		blob := &Blob{
+			Value: []byte("{\"fieldID\":1,\"max\":10, \"min\":1}"),
+		}
+		_, err := DeserializeFieldStats(blob)
+		assert.NoError(t, err)
+	})
+
+	t.Run("valid varchar field stats", func(t *testing.T) {
+		blob := &Blob{
+			Value: []byte("{\"fieldID\":1,\"type\":21,\"max\":\"z\", \"min\":\"a\"}"),
+		}
+		_, err := DeserializeFieldStats(blob)
+		assert.NoError(t, err)
+	})
+}
+
+// TestCompatible_ReadPrimaryKeyStatsWithFieldStatsReader writes stats with the
+// primary-key StatsWriter and reads them back through FieldStatsReader.
+func TestCompatible_ReadPrimaryKeyStatsWithFieldStatsReader(t *testing.T) {
+	data := &Int64FieldData{
+		Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9},
+	}
+	sw := &StatsWriter{}
+	err := sw.GenerateByData(common.RowIDField, schemapb.DataType_Int64, data)
+	assert.NoError(t, err)
+	b := sw.GetBuffer()
+
+	sr := &FieldStatsReader{}
+	sr.SetBuffer(b)
+	stats, err := sr.GetFieldStatsList()
+	assert.NoError(t, err)
+	maxPk := &Int64FieldValue{
+		Value: 9,
+	}
+	minPk := &Int64FieldValue{
+		Value: 1,
+	}
+	assert.True(t, stats[0].Max.EQ(maxPk))
+	assert.True(t, stats[0].Min.EQ(minPk))
+	assert.Equal(t, schemapb.DataType_Int64.String(), stats[0].Type.String())
+	buffer := make([]byte, 8)
+	for _, id := range data.Data {
+		common.Endian.PutUint64(buffer, uint64(id))
+		assert.True(t, stats[0].BF.Test(buffer))
+	}
+
+	msgs := &Int64FieldData{
+		Data: []int64{},
+	}
+	err = sw.GenerateByData(common.RowIDField, schemapb.DataType_Int64, msgs)
+	assert.NoError(t, err)
+}
+
+func TestFieldStatsUnMarshal(t *testing.T) {
+	t.Run("fail", func(t *testing.T) {
+		stats, err := NewFieldStats(1, schemapb.DataType_Int64, 1)
+		assert.NoError(t, err)
+		err = stats.UnmarshalJSON([]byte("{\"fieldID\":1,\"max\":10, }"))
+		assert.Error(t, err)
+		err = stats.UnmarshalJSON([]byte("{\"fieldID\":1,\"max\":10, \"maxPk\":\"A\"}"))
+		assert.Error(t, err)
+		err = stats.UnmarshalJSON([]byte("{\"fieldID\":1,\"max\":10, \"maxPk\":10, \"minPk\": \"b\"}"))
+		assert.Error(t, err)
+		err = stats.UnmarshalJSON([]byte("{\"fieldID\":1,\"max\":10, \"maxPk\":10, \"minPk\": 1, \"bf\": \"2\"}"))
+		assert.Error(t, err)
+	})
+
+	t.Run("succeed", func(t *testing.T) {
+		int8stats, err := NewFieldStats(1, schemapb.DataType_Int8, 1)
+		assert.NoError(t, err)
+		err = int8stats.UnmarshalJSON([]byte("{\"type\":2, \"fieldID\":1,\"max\":10, \"min\": 1}"))
+		assert.NoError(t, err)
+
+		int16stats, err := NewFieldStats(1, schemapb.DataType_Int16, 1)
+		assert.NoError(t, err)
+		err = int16stats.UnmarshalJSON([]byte("{\"type\":3, \"fieldID\":1,\"max\":10, \"min\": 1}"))
+		assert.NoError(t, err)
+
+		int32stats, err := NewFieldStats(1, schemapb.DataType_Int32, 1)
+		assert.NoError(t, err)
+		err = int32stats.UnmarshalJSON([]byte("{\"type\":4, \"fieldID\":1,\"max\":10, \"min\": 1}"))
+		assert.NoError(t, err)
+
+		int64stats, err := NewFieldStats(1, schemapb.DataType_Int64, 1)
+		assert.NoError(t, err)
+		err = int64stats.UnmarshalJSON([]byte("{\"type\":5, \"fieldID\":1,\"max\":10, \"min\": 1}"))
+		assert.NoError(t, err)
+
+		floatstats, err := NewFieldStats(1, schemapb.DataType_Float, 1)
+		assert.NoError(t, err)
+		err = floatstats.UnmarshalJSON([]byte("{\"type\":10, \"fieldID\":1,\"max\":10.0, \"min\": 1.2}"))
+		assert.NoError(t, err)
+
+		doublestats, err := NewFieldStats(1, schemapb.DataType_Double, 1)
+		assert.NoError(t, err)
+		err = doublestats.UnmarshalJSON([]byte("{\"type\":11, \"fieldID\":1,\"max\":10.0, \"min\": 1.2}"))
+		assert.NoError(t, err)
+	})
+}
+
+// TestCompatible_ReadFieldStatsWithPrimaryKeyStatsReader writes stats with
+// FieldStatsWriter and reads them back through the primary-key StatsReader.
+func TestCompatible_ReadFieldStatsWithPrimaryKeyStatsReader(t *testing.T) {
+	data := &Int64FieldData{
+		Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9},
+	}
+	sw := &FieldStatsWriter{}
+	err := sw.GenerateByData(common.RowIDField, schemapb.DataType_Int64, data)
+	assert.NoError(t, err)
+	b := sw.GetBuffer()
+
+	sr := &StatsReader{}
+	sr.SetBuffer(b)
+	statsList, err := sr.GetPrimaryKeyStatsList()
+	assert.NoError(t, err)
+	stats := statsList[0]
+	maxPk := &Int64PrimaryKey{
+		Value: 9,
+	}
+	minPk := &Int64PrimaryKey{
+		Value: 1,
+	}
+	assert.True(t, stats.MaxPk.EQ(maxPk))
+	assert.True(t, stats.MinPk.EQ(minPk))
+	buffer := make([]byte, 8)
+	for _, id := range data.Data {
+		common.Endian.PutUint64(buffer, uint64(id))
+		assert.True(t, stats.BF.Test(buffer))
+	}
+
+	msgs := &Int64FieldData{
+		Data: []int64{},
+	}
+	err = sw.GenerateByData(common.RowIDField, schemapb.DataType_Int64, msgs)
+	assert.NoError(t, err)
+}
+
+func TestMultiFieldStats(t *testing.T) {
+	pkData := &Int64FieldData{
+		Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9},
+	}
+	partitionKeyData := &Int64FieldData{
+		Data: []int64{1, 10, 21, 31, 41, 51, 61, 71, 81},
+	}
+
+	sw := &FieldStatsWriter{}
+	err := sw.GenerateByData(common.RowIDField, schemapb.DataType_Int64, pkData, partitionKeyData)
+	assert.NoError(t, err)
+	b := sw.GetBuffer()
+
+	sr := &FieldStatsReader{}
+	sr.SetBuffer(b)
+	statsList, err := sr.GetFieldStatsList()
+	assert.NoError(t, err)
+	assert.Len(t, statsList, 2)
+
+	pkStats := statsList[0]
+	maxPk := NewInt64FieldValue(9)
+	minPk := NewInt64FieldValue(1)
+	assert.True(t, pkStats.Max.EQ(maxPk))
+	assert.True(t, pkStats.Min.EQ(minPk))
+
+	partitionKeyStats := statsList[1]
+	maxPk2 := NewInt64FieldValue(81)
+	minPk2 := NewInt64FieldValue(1)
+	assert.True(t, partitionKeyStats.Max.EQ(maxPk2))
+	assert.True(t, partitionKeyStats.Min.EQ(minPk2))
+}
+
+func TestVectorFieldStatsMarshal(t *testing.T) {
+	stats, err := NewFieldStats(1, schemapb.DataType_FloatVector, 1)
+	assert.NoError(t, err)
+	centroid := NewFloatVectorFieldValue([]float32{1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0})
+	stats.SetVectorCentroids(centroid)
+
+	bytes, err := json.Marshal(stats)
+	assert.NoError(t, err)
+
+	stats2, err := NewFieldStats(1, schemapb.DataType_FloatVector, 1)
+	assert.NoError(t, err)
+	assert.NoError(t, stats2.UnmarshalJSON(bytes))
+	assert.Equal(t, 1, len(stats2.Centroids))
+	assert.ElementsMatch(t, []VectorFieldValue{centroid}, stats2.Centroids)
+
+	stats3, err := NewFieldStats(1, schemapb.DataType_FloatVector, 2)
+	assert.NoError(t, err)
+	centroid2 := NewFloatVectorFieldValue([]float32{9.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0})
+	stats3.SetVectorCentroids(centroid, centroid2)
+
+	bytes2, err := json.Marshal(stats3)
+	assert.NoError(t, err)
+
+	stats4, err := NewFieldStats(1, schemapb.DataType_FloatVector, 2)
+	assert.NoError(t, err)
+	assert.NoError(t, stats4.UnmarshalJSON(bytes2))
+	assert.Equal(t, 2, len(stats4.Centroids))
+	assert.ElementsMatch(t, []VectorFieldValue{centroid, centroid2}, stats4.Centroids)
+}
diff --git a/internal/storage/field_value.go b/internal/storage/field_value.go
new file mode 100644
index 0000000000..d9f50cb6e3
--- /dev/null
+++ b/internal/storage/field_value.go
@@ -0,0 +1,1015 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package storage + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/log" +) + +type ScalarFieldValue interface { + GT(key ScalarFieldValue) bool + GE(key ScalarFieldValue) bool + LT(key ScalarFieldValue) bool + LE(key ScalarFieldValue) bool + EQ(key ScalarFieldValue) bool + MarshalJSON() ([]byte, error) + UnmarshalJSON(data []byte) error + SetValue(interface{}) error + GetValue() interface{} + Type() schemapb.DataType + Size() int64 +} + +// DataType_Int8 +type Int8FieldValue struct { + Value int8 `json:"value"` +} + +func NewInt8FieldValue(v int8) *Int8FieldValue { + return &Int8FieldValue{ + Value: v, + } +} + +func (ifv *Int8FieldValue) GT(obj ScalarFieldValue) bool { + v, ok := obj.(*Int8FieldValue) + if !ok { + log.Warn("type of compared pk is not int64") + return false + } + if ifv.Value > v.Value { + return true + } + + return false +} + +func (ifv *Int8FieldValue) GE(obj ScalarFieldValue) bool { + v, ok := obj.(*Int8FieldValue) + if !ok { + log.Warn("type of compared pk is not int64") + return false + } + if ifv.Value >= v.Value { + return true + } + + return false +} + +func (ifv *Int8FieldValue) LT(obj ScalarFieldValue) bool { + v, ok := obj.(*Int8FieldValue) + if !ok { + log.Warn("type of compared pk is not int64") + return false + } + + if ifv.Value < v.Value { + return true + } + + return false +} + +func (ifv *Int8FieldValue) LE(obj ScalarFieldValue) bool { + v, ok := obj.(*Int8FieldValue) + if !ok { + log.Warn("type of compared obj is not int64") + return false + } + + if ifv.Value <= v.Value { + return true + } + + return false +} + +func (ifv *Int8FieldValue) EQ(obj ScalarFieldValue) bool { + v, ok := obj.(*Int8FieldValue) + if !ok { + log.Warn("type of compared obj is not int64") + return false + } + + if ifv.Value == v.Value { + return true + } + + return false +} + +func (ifv *Int8FieldValue) MarshalJSON() ([]byte, error) { + ret, err := 
json.Marshal(ifv.Value) + if err != nil { + return nil, err + } + + return ret, nil +} + +func (ifv *Int8FieldValue) UnmarshalJSON(data []byte) error { + err := json.Unmarshal(data, &ifv.Value) + if err != nil { + return err + } + + return nil +} + +func (ifv *Int8FieldValue) SetValue(data interface{}) error { + value, ok := data.(int8) + if !ok { + log.Warn("wrong type value when setValue for Int64FieldValue") + return fmt.Errorf("wrong type value when setValue for Int64FieldValue") + } + + ifv.Value = value + return nil +} + +func (ifv *Int8FieldValue) Type() schemapb.DataType { + return schemapb.DataType_Int8 +} + +func (ifv *Int8FieldValue) GetValue() interface{} { + return ifv.Value +} + +func (ifv *Int8FieldValue) Size() int64 { + return 2 +} + +// DataType_Int16 +type Int16FieldValue struct { + Value int16 `json:"value"` +} + +func NewInt16FieldValue(v int16) *Int16FieldValue { + return &Int16FieldValue{ + Value: v, + } +} + +func (ifv *Int16FieldValue) GT(obj ScalarFieldValue) bool { + v, ok := obj.(*Int16FieldValue) + if !ok { + log.Warn("type of compared pk is not int64") + return false + } + if ifv.Value > v.Value { + return true + } + + return false +} + +func (ifv *Int16FieldValue) GE(obj ScalarFieldValue) bool { + v, ok := obj.(*Int16FieldValue) + if !ok { + log.Warn("type of compared pk is not int64") + return false + } + if ifv.Value >= v.Value { + return true + } + + return false +} + +func (ifv *Int16FieldValue) LT(obj ScalarFieldValue) bool { + v, ok := obj.(*Int16FieldValue) + if !ok { + log.Warn("type of compared pk is not int64") + return false + } + + if ifv.Value < v.Value { + return true + } + + return false +} + +func (ifv *Int16FieldValue) LE(obj ScalarFieldValue) bool { + v, ok := obj.(*Int16FieldValue) + if !ok { + log.Warn("type of compared obj is not int64") + return false + } + + if ifv.Value <= v.Value { + return true + } + + return false +} + +func (ifv *Int16FieldValue) EQ(obj ScalarFieldValue) bool { + v, ok := 
obj.(*Int16FieldValue) + if !ok { + log.Warn("type of compared obj is not int64") + return false + } + + if ifv.Value == v.Value { + return true + } + + return false +} + +func (ifv *Int16FieldValue) MarshalJSON() ([]byte, error) { + ret, err := json.Marshal(ifv.Value) + if err != nil { + return nil, err + } + + return ret, nil +} + +func (ifv *Int16FieldValue) UnmarshalJSON(data []byte) error { + err := json.Unmarshal(data, &ifv.Value) + if err != nil { + return err + } + + return nil +} + +func (ifv *Int16FieldValue) SetValue(data interface{}) error { + value, ok := data.(int16) + if !ok { + log.Warn("wrong type value when setValue for Int64FieldValue") + return fmt.Errorf("wrong type value when setValue for Int64FieldValue") + } + + ifv.Value = value + return nil +} + +func (ifv *Int16FieldValue) Type() schemapb.DataType { + return schemapb.DataType_Int16 +} + +func (ifv *Int16FieldValue) GetValue() interface{} { + return ifv.Value +} + +func (ifv *Int16FieldValue) Size() int64 { + return 4 +} + +// DataType_Int32 +type Int32FieldValue struct { + Value int32 `json:"value"` +} + +func NewInt32FieldValue(v int32) *Int32FieldValue { + return &Int32FieldValue{ + Value: v, + } +} + +func (ifv *Int32FieldValue) GT(obj ScalarFieldValue) bool { + v, ok := obj.(*Int32FieldValue) + if !ok { + log.Warn("type of compared pk is not int64") + return false + } + if ifv.Value > v.Value { + return true + } + + return false +} + +func (ifv *Int32FieldValue) GE(obj ScalarFieldValue) bool { + v, ok := obj.(*Int32FieldValue) + if !ok { + log.Warn("type of compared pk is not int64") + return false + } + if ifv.Value >= v.Value { + return true + } + + return false +} + +func (ifv *Int32FieldValue) LT(obj ScalarFieldValue) bool { + v, ok := obj.(*Int32FieldValue) + if !ok { + log.Warn("type of compared pk is not int64") + return false + } + + if ifv.Value < v.Value { + return true + } + + return false +} + +func (ifv *Int32FieldValue) LE(obj ScalarFieldValue) bool { + v, ok := 
obj.(*Int32FieldValue) + if !ok { + log.Warn("type of compared obj is not int64") + return false + } + + if ifv.Value <= v.Value { + return true + } + + return false +} + +func (ifv *Int32FieldValue) EQ(obj ScalarFieldValue) bool { + v, ok := obj.(*Int32FieldValue) + if !ok { + log.Warn("type of compared obj is not int64") + return false + } + + if ifv.Value == v.Value { + return true + } + + return false +} + +func (ifv *Int32FieldValue) MarshalJSON() ([]byte, error) { + ret, err := json.Marshal(ifv.Value) + if err != nil { + return nil, err + } + + return ret, nil +} + +func (ifv *Int32FieldValue) UnmarshalJSON(data []byte) error { + err := json.Unmarshal(data, &ifv.Value) + if err != nil { + return err + } + + return nil +} + +func (ifv *Int32FieldValue) SetValue(data interface{}) error { + value, ok := data.(int32) + if !ok { + log.Warn("wrong type value when setValue for Int64FieldValue") + return fmt.Errorf("wrong type value when setValue for Int64FieldValue") + } + + ifv.Value = value + return nil +} + +func (ifv *Int32FieldValue) Type() schemapb.DataType { + return schemapb.DataType_Int32 +} + +func (ifv *Int32FieldValue) GetValue() interface{} { + return ifv.Value +} + +func (ifv *Int32FieldValue) Size() int64 { + return 8 +} + +// DataType_Int64 +type Int64FieldValue struct { + Value int64 `json:"value"` +} + +func NewInt64FieldValue(v int64) *Int64FieldValue { + return &Int64FieldValue{ + Value: v, + } +} + +func (ifv *Int64FieldValue) GT(obj ScalarFieldValue) bool { + v, ok := obj.(*Int64FieldValue) + if !ok { + log.Warn("type of compared pk is not int64") + return false + } + if ifv.Value > v.Value { + return true + } + + return false +} + +func (ifv *Int64FieldValue) GE(obj ScalarFieldValue) bool { + v, ok := obj.(*Int64FieldValue) + if !ok { + log.Warn("type of compared pk is not int64") + return false + } + if ifv.Value >= v.Value { + return true + } + + return false +} + +func (ifv *Int64FieldValue) LT(obj ScalarFieldValue) bool { + v, ok := 
obj.(*Int64FieldValue) + if !ok { + log.Warn("type of compared pk is not int64") + return false + } + + if ifv.Value < v.Value { + return true + } + + return false +} + +func (ifv *Int64FieldValue) LE(obj ScalarFieldValue) bool { + v, ok := obj.(*Int64FieldValue) + if !ok { + log.Warn("type of compared obj is not int64") + return false + } + + if ifv.Value <= v.Value { + return true + } + + return false +} + +func (ifv *Int64FieldValue) EQ(obj ScalarFieldValue) bool { + v, ok := obj.(*Int64FieldValue) + if !ok { + log.Warn("type of compared obj is not int64") + return false + } + + if ifv.Value == v.Value { + return true + } + + return false +} + +func (ifv *Int64FieldValue) MarshalJSON() ([]byte, error) { + ret, err := json.Marshal(ifv.Value) + if err != nil { + return nil, err + } + + return ret, nil +} + +func (ifv *Int64FieldValue) UnmarshalJSON(data []byte) error { + err := json.Unmarshal(data, &ifv.Value) + if err != nil { + return err + } + + return nil +} + +func (ifv *Int64FieldValue) SetValue(data interface{}) error { + value, ok := data.(int64) + if !ok { + log.Warn("wrong type value when setValue for Int64FieldValue") + return fmt.Errorf("wrong type value when setValue for Int64FieldValue") + } + + ifv.Value = value + return nil +} + +func (ifv *Int64FieldValue) Type() schemapb.DataType { + return schemapb.DataType_Int64 +} + +func (ifv *Int64FieldValue) GetValue() interface{} { + return ifv.Value +} + +func (ifv *Int64FieldValue) Size() int64 { + // 8 + reflect.ValueOf(Int64FieldValue).Type().Size() + return 16 +} + +// DataType_Float +type FloatFieldValue struct { + Value float32 `json:"value"` +} + +func NewFloatFieldValue(v float32) *FloatFieldValue { + return &FloatFieldValue{ + Value: v, + } +} + +func (ifv *FloatFieldValue) GT(obj ScalarFieldValue) bool { + v, ok := obj.(*FloatFieldValue) + if !ok { + log.Warn("type of compared pk is not int64") + return false + } + if ifv.Value > v.Value { + return true + } + + return false +} + +func (ifv 
*FloatFieldValue) GE(obj ScalarFieldValue) bool { + v, ok := obj.(*FloatFieldValue) + if !ok { + log.Warn("type of compared pk is not int64") + return false + } + if ifv.Value >= v.Value { + return true + } + + return false +} + +func (ifv *FloatFieldValue) LT(obj ScalarFieldValue) bool { + v, ok := obj.(*FloatFieldValue) + if !ok { + log.Warn("type of compared pk is not int64") + return false + } + + if ifv.Value < v.Value { + return true + } + + return false +} + +func (ifv *FloatFieldValue) LE(obj ScalarFieldValue) bool { + v, ok := obj.(*FloatFieldValue) + if !ok { + log.Warn("type of compared obj is not int64") + return false + } + + if ifv.Value <= v.Value { + return true + } + + return false +} + +func (ifv *FloatFieldValue) EQ(obj ScalarFieldValue) bool { + v, ok := obj.(*FloatFieldValue) + if !ok { + log.Warn("type of compared obj is not int64") + return false + } + + if ifv.Value == v.Value { + return true + } + + return false +} + +func (ifv *FloatFieldValue) MarshalJSON() ([]byte, error) { + ret, err := json.Marshal(ifv.Value) + if err != nil { + return nil, err + } + + return ret, nil +} + +func (ifv *FloatFieldValue) UnmarshalJSON(data []byte) error { + err := json.Unmarshal(data, &ifv.Value) + if err != nil { + return err + } + + return nil +} + +func (ifv *FloatFieldValue) SetValue(data interface{}) error { + value, ok := data.(float32) + if !ok { + log.Warn("wrong type value when setValue for FloatFieldValue") + return fmt.Errorf("wrong type value when setValue for FloatFieldValue") + } + + ifv.Value = value + return nil +} + +func (ifv *FloatFieldValue) Type() schemapb.DataType { + return schemapb.DataType_Float +} + +func (ifv *FloatFieldValue) GetValue() interface{} { + return ifv.Value +} + +func (ifv *FloatFieldValue) Size() int64 { + return 8 +} + +// DataType_Double +type DoubleFieldValue struct { + Value float64 `json:"value"` +} + +func NewDoubleFieldValue(v float64) *DoubleFieldValue { + return &DoubleFieldValue{ + Value: v, + } +} + 
+func (ifv *DoubleFieldValue) GT(obj ScalarFieldValue) bool { + v, ok := obj.(*DoubleFieldValue) + if !ok { + log.Warn("type of compared pk is not int64") + return false + } + if ifv.Value > v.Value { + return true + } + + return false +} + +func (ifv *DoubleFieldValue) GE(obj ScalarFieldValue) bool { + v, ok := obj.(*DoubleFieldValue) + if !ok { + log.Warn("type of compared pk is not int64") + return false + } + if ifv.Value >= v.Value { + return true + } + + return false +} + +func (ifv *DoubleFieldValue) LT(obj ScalarFieldValue) bool { + v, ok := obj.(*DoubleFieldValue) + if !ok { + log.Warn("type of compared pk is not int64") + return false + } + + if ifv.Value < v.Value { + return true + } + + return false +} + +func (ifv *DoubleFieldValue) LE(obj ScalarFieldValue) bool { + v, ok := obj.(*DoubleFieldValue) + if !ok { + log.Warn("type of compared obj is not int64") + return false + } + + if ifv.Value <= v.Value { + return true + } + + return false +} + +func (ifv *DoubleFieldValue) EQ(obj ScalarFieldValue) bool { + v, ok := obj.(*DoubleFieldValue) + if !ok { + log.Warn("type of compared obj is not int64") + return false + } + + if ifv.Value == v.Value { + return true + } + + return false +} + +func (ifv *DoubleFieldValue) MarshalJSON() ([]byte, error) { + ret, err := json.Marshal(ifv.Value) + if err != nil { + return nil, err + } + + return ret, nil +} + +func (ifv *DoubleFieldValue) UnmarshalJSON(data []byte) error { + err := json.Unmarshal(data, &ifv.Value) + if err != nil { + return err + } + + return nil +} + +func (ifv *DoubleFieldValue) SetValue(data interface{}) error { + value, ok := data.(float64) + if !ok { + log.Warn("wrong type value when setValue for DoubleFieldValue") + return fmt.Errorf("wrong type value when setValue for DoubleFieldValue") + } + + ifv.Value = value + return nil +} + +func (ifv *DoubleFieldValue) Type() schemapb.DataType { + return schemapb.DataType_Double +} + +func (ifv *DoubleFieldValue) GetValue() interface{} { + return 
ifv.Value +} + +func (ifv *DoubleFieldValue) Size() int64 { + return 16 +} + +type StringFieldValue struct { + Value string `json:"value"` +} + +func NewStringFieldValue(v string) *StringFieldValue { + return &StringFieldValue{ + Value: v, + } +} + +func (sfv *StringFieldValue) GT(obj ScalarFieldValue) bool { + v, ok := obj.(*StringFieldValue) + if !ok { + log.Warn("type of compared obj is not varchar") + return false + } + + return strings.Compare(sfv.Value, v.Value) > 0 +} + +func (sfv *StringFieldValue) GE(obj ScalarFieldValue) bool { + v, ok := obj.(*StringFieldValue) + if !ok { + log.Warn("type of compared obj is not varchar") + return false + } + return strings.Compare(sfv.Value, v.Value) >= 0 +} + +func (sfv *StringFieldValue) LT(obj ScalarFieldValue) bool { + v, ok := obj.(*StringFieldValue) + if !ok { + log.Warn("type of compared obj is not varchar") + return false + } + return strings.Compare(sfv.Value, v.Value) < 0 +} + +func (sfv *StringFieldValue) LE(obj ScalarFieldValue) bool { + v, ok := obj.(*StringFieldValue) + if !ok { + log.Warn("type of compared obj is not varchar") + return false + } + return strings.Compare(sfv.Value, v.Value) <= 0 +} + +func (sfv *StringFieldValue) EQ(obj ScalarFieldValue) bool { + v, ok := obj.(*StringFieldValue) + if !ok { + log.Warn("type of compared obj is not varchar") + return false + } + return strings.Compare(sfv.Value, v.Value) == 0 +} + +func (sfv *StringFieldValue) MarshalJSON() ([]byte, error) { + ret, err := json.Marshal(sfv.Value) + if err != nil { + return nil, err + } + + return ret, nil +} + +func (sfv *StringFieldValue) UnmarshalJSON(data []byte) error { + err := json.Unmarshal(data, &sfv.Value) + if err != nil { + return err + } + + return nil +} + +func (sfv *StringFieldValue) SetValue(data interface{}) error { + value, ok := data.(string) + if !ok { + return fmt.Errorf("wrong type value when setValue for StringFieldValue") + } + + sfv.Value = value + return nil +} + +func (sfv *StringFieldValue) 
GetValue() interface{} { + return sfv.Value +} + +func (sfv *StringFieldValue) Type() schemapb.DataType { + return schemapb.DataType_String +} + +func (sfv *StringFieldValue) Size() int64 { + return int64(8*len(sfv.Value) + 8) +} + +type VarCharFieldValue struct { + Value string `json:"value"` +} + +func NewVarCharFieldValue(v string) *VarCharFieldValue { + return &VarCharFieldValue{ + Value: v, + } +} + +func (vcfv *VarCharFieldValue) GT(obj ScalarFieldValue) bool { + v, ok := obj.(*VarCharFieldValue) + if !ok { + log.Warn("type of compared obj is not varchar") + return false + } + + return strings.Compare(vcfv.Value, v.Value) > 0 +} + +func (vcfv *VarCharFieldValue) GE(obj ScalarFieldValue) bool { + v, ok := obj.(*VarCharFieldValue) + if !ok { + log.Warn("type of compared obj is not varchar") + return false + } + return strings.Compare(vcfv.Value, v.Value) >= 0 +} + +func (vcfv *VarCharFieldValue) LT(obj ScalarFieldValue) bool { + v, ok := obj.(*VarCharFieldValue) + if !ok { + log.Warn("type of compared obj is not varchar") + return false + } + return strings.Compare(vcfv.Value, v.Value) < 0 +} + +func (vcfv *VarCharFieldValue) LE(obj ScalarFieldValue) bool { + v, ok := obj.(*VarCharFieldValue) + if !ok { + log.Warn("type of compared obj is not varchar") + return false + } + return strings.Compare(vcfv.Value, v.Value) <= 0 +} + +func (vcfv *VarCharFieldValue) EQ(obj ScalarFieldValue) bool { + v, ok := obj.(*VarCharFieldValue) + if !ok { + log.Warn("type of compared obj is not varchar") + return false + } + return strings.Compare(vcfv.Value, v.Value) == 0 +} + +func (vcfv *VarCharFieldValue) SetValue(data interface{}) error { + value, ok := data.(string) + if !ok { + return fmt.Errorf("wrong type value when setValue for StringFieldValue") + } + + vcfv.Value = value + return nil +} + +func (vcfv *VarCharFieldValue) GetValue() interface{} { + return vcfv.Value +} + +func (vcfv *VarCharFieldValue) Type() schemapb.DataType { + return schemapb.DataType_VarChar +} + 
+func (vcfv *VarCharFieldValue) Size() int64 { + return int64(8*len(vcfv.Value) + 8) +} + +func (vcfv *VarCharFieldValue) MarshalJSON() ([]byte, error) { + ret, err := json.Marshal(vcfv.Value) + if err != nil { + return nil, err + } + + return ret, nil +} + +func (vcfv *VarCharFieldValue) UnmarshalJSON(data []byte) error { + err := json.Unmarshal(data, &vcfv.Value) + if err != nil { + return err + } + + return nil +} + +type VectorFieldValue interface { + MarshalJSON() ([]byte, error) + UnmarshalJSON(data []byte) error + SetValue(interface{}) error + GetValue() interface{} + Type() schemapb.DataType + Size() int64 +} + +var _ VectorFieldValue = (*FloatVectorFieldValue)(nil) + +type FloatVectorFieldValue struct { + Value []float32 `json:"value"` +} + +func NewFloatVectorFieldValue(v []float32) *FloatVectorFieldValue { + return &FloatVectorFieldValue{ + Value: v, + } +} + +func (ifv *FloatVectorFieldValue) MarshalJSON() ([]byte, error) { + ret, err := json.Marshal(ifv.Value) + if err != nil { + return nil, err + } + + return ret, nil +} + +func (ifv *FloatVectorFieldValue) UnmarshalJSON(data []byte) error { + err := json.Unmarshal(data, &ifv.Value) + if err != nil { + return err + } + + return nil +} + +func (ifv *FloatVectorFieldValue) SetValue(data interface{}) error { + value, ok := data.([]float32) + if !ok { + log.Warn("wrong type value when setValue for FloatVectorFieldValue") + return fmt.Errorf("wrong type value when setValue for FloatVectorFieldValue") + } + + ifv.Value = value + return nil +} + +func (ifv *FloatVectorFieldValue) Type() schemapb.DataType { + return schemapb.DataType_FloatVector +} + +func (ifv *FloatVectorFieldValue) GetValue() interface{} { + return ifv.Value +} + +func (ifv *FloatVectorFieldValue) Size() int64 { + return int64(len(ifv.Value) * 8) +} diff --git a/internal/storage/field_value_test.go b/internal/storage/field_value_test.go new file mode 100644 index 0000000000..0c24c70c2b --- /dev/null +++ 
b/internal/storage/field_value_test.go @@ -0,0 +1,353 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestVarCharFieldValue(t *testing.T) { + pk := NewVarCharFieldValue("milvus") + + testPk := NewVarCharFieldValue("milvus") + + // test GE + assert.Equal(t, true, pk.GE(testPk)) + // test LE + assert.Equal(t, true, pk.LE(testPk)) + // test EQ + assert.Equal(t, true, pk.EQ(testPk)) + + err := testPk.SetValue(1.0) + assert.Error(t, err) + + // test GT + err = testPk.SetValue("bivlus") + assert.NoError(t, err) + assert.Equal(t, true, pk.GT(testPk)) + assert.Equal(t, false, testPk.GT(pk)) + + // test LT + err = testPk.SetValue("mivlut") + assert.NoError(t, err) + assert.Equal(t, true, pk.LT(testPk)) + assert.Equal(t, false, testPk.LT(pk)) + + t.Run("unmarshal", func(t *testing.T) { + blob, err := json.Marshal(pk) + assert.NoError(t, err) + + unmarshalledPk := &VarCharFieldValue{} + err = json.Unmarshal(blob, unmarshalledPk) + assert.NoError(t, err) + assert.Equal(t, pk.Value, unmarshalledPk.Value) + }) +} + +func TestInt64FieldValue(t *testing.T) { + pk := NewInt64FieldValue(100) + + testPk := NewInt64FieldValue(100) + // 
test GE + assert.Equal(t, true, pk.GE(testPk)) + assert.Equal(t, true, testPk.GE(pk)) + // test LE + assert.Equal(t, true, pk.LE(testPk)) + assert.Equal(t, true, testPk.LE(pk)) + // test EQ + assert.Equal(t, true, pk.EQ(testPk)) + + err := testPk.SetValue(1.0) + assert.Error(t, err) + + // test GT + err = testPk.SetValue(int64(10)) + assert.NoError(t, err) + assert.Equal(t, true, pk.GT(testPk)) + assert.Equal(t, false, testPk.GT(pk)) + assert.Equal(t, true, pk.GE(testPk)) + assert.Equal(t, false, testPk.GE(pk)) + + // test LT + err = testPk.SetValue(int64(200)) + assert.NoError(t, err) + assert.Equal(t, true, pk.LT(testPk)) + assert.Equal(t, false, testPk.LT(pk)) + assert.Equal(t, true, pk.LE(testPk)) + assert.Equal(t, false, testPk.LE(pk)) + + t.Run("unmarshal", func(t *testing.T) { + blob, err := json.Marshal(pk) + assert.NoError(t, err) + + unmarshalledPk := &Int64FieldValue{} + err = json.Unmarshal(blob, unmarshalledPk) + assert.NoError(t, err) + assert.Equal(t, pk.Value, unmarshalledPk.Value) + }) +} + +func TestInt8FieldValue(t *testing.T) { + pk := NewInt8FieldValue(20) + + testPk := NewInt8FieldValue(20) + // test GE + assert.Equal(t, true, pk.GE(testPk)) + assert.Equal(t, true, testPk.GE(pk)) + // test LE + assert.Equal(t, true, pk.LE(testPk)) + assert.Equal(t, true, testPk.LE(pk)) + // test EQ + assert.Equal(t, true, pk.EQ(testPk)) + + err := testPk.SetValue(1.0) + assert.Error(t, err) + + // test GT + err = testPk.SetValue(int8(10)) + assert.NoError(t, err) + assert.Equal(t, true, pk.GT(testPk)) + assert.Equal(t, false, testPk.GT(pk)) + assert.Equal(t, true, pk.GE(testPk)) + assert.Equal(t, false, testPk.GE(pk)) + + // test LT + err = testPk.SetValue(int8(30)) + assert.NoError(t, err) + assert.Equal(t, true, pk.LT(testPk)) + assert.Equal(t, false, testPk.LT(pk)) + assert.Equal(t, true, pk.LE(testPk)) + assert.Equal(t, false, testPk.LE(pk)) + + t.Run("unmarshal", func(t *testing.T) { + blob, err := json.Marshal(pk) + assert.NoError(t, err) + + 
unmarshalledPk := &Int8FieldValue{} + err = json.Unmarshal(blob, unmarshalledPk) + assert.NoError(t, err) + assert.Equal(t, pk.Value, unmarshalledPk.Value) + }) +} + +func TestInt16FieldValue(t *testing.T) { + pk := NewInt16FieldValue(100) + + testPk := NewInt16FieldValue(100) + // test GE + assert.Equal(t, true, pk.GE(testPk)) + assert.Equal(t, true, testPk.GE(pk)) + // test LE + assert.Equal(t, true, pk.LE(testPk)) + assert.Equal(t, true, testPk.LE(pk)) + // test EQ + assert.Equal(t, true, pk.EQ(testPk)) + + err := testPk.SetValue(1.0) + assert.Error(t, err) + + // test GT + err = testPk.SetValue(int16(10)) + assert.NoError(t, err) + assert.Equal(t, true, pk.GT(testPk)) + assert.Equal(t, false, testPk.GT(pk)) + assert.Equal(t, true, pk.GE(testPk)) + assert.Equal(t, false, testPk.GE(pk)) + + // test LT + err = testPk.SetValue(int16(200)) + assert.NoError(t, err) + assert.Equal(t, true, pk.LT(testPk)) + assert.Equal(t, false, testPk.LT(pk)) + assert.Equal(t, true, pk.LE(testPk)) + assert.Equal(t, false, testPk.LE(pk)) + + t.Run("unmarshal", func(t *testing.T) { + blob, err := json.Marshal(pk) + assert.NoError(t, err) + + unmarshalledPk := &Int16FieldValue{} + err = json.Unmarshal(blob, unmarshalledPk) + assert.NoError(t, err) + assert.Equal(t, pk.Value, unmarshalledPk.Value) + }) +} + +func TestInt32FieldValue(t *testing.T) { + pk := NewInt32FieldValue(100) + + testPk := NewInt32FieldValue(100) + // test GE + assert.Equal(t, true, pk.GE(testPk)) + assert.Equal(t, true, testPk.GE(pk)) + // test LE + assert.Equal(t, true, pk.LE(testPk)) + assert.Equal(t, true, testPk.LE(pk)) + // test EQ + assert.Equal(t, true, pk.EQ(testPk)) + + err := testPk.SetValue(1.0) + assert.Error(t, err) + + // test GT + err = testPk.SetValue(int32(10)) + assert.NoError(t, err) + assert.Equal(t, true, pk.GT(testPk)) + assert.Equal(t, false, testPk.GT(pk)) + assert.Equal(t, true, pk.GE(testPk)) + assert.Equal(t, false, testPk.GE(pk)) + + // test LT + err = testPk.SetValue(int32(200)) + 
assert.NoError(t, err) + assert.Equal(t, true, pk.LT(testPk)) + assert.Equal(t, false, testPk.LT(pk)) + assert.Equal(t, true, pk.LE(testPk)) + assert.Equal(t, false, testPk.LE(pk)) + + t.Run("unmarshal", func(t *testing.T) { + blob, err := json.Marshal(pk) + assert.NoError(t, err) + + unmarshalledPk := &Int32FieldValue{} + err = json.Unmarshal(blob, unmarshalledPk) + assert.NoError(t, err) + assert.Equal(t, pk.Value, unmarshalledPk.Value) + }) +} + +func TestFloatFieldValue(t *testing.T) { + pk := NewFloatFieldValue(100) + + testPk := NewFloatFieldValue(100) + // test GE + assert.Equal(t, true, pk.GE(testPk)) + assert.Equal(t, true, testPk.GE(pk)) + // test LE + assert.Equal(t, true, pk.LE(testPk)) + assert.Equal(t, true, testPk.LE(pk)) + // test EQ + assert.Equal(t, true, pk.EQ(testPk)) + + err := testPk.SetValue(float32(1.0)) + assert.NoError(t, err) + // test GT + err = testPk.SetValue(float32(10)) + assert.NoError(t, err) + assert.Equal(t, true, pk.GT(testPk)) + assert.Equal(t, false, testPk.GT(pk)) + assert.Equal(t, true, pk.GE(testPk)) + assert.Equal(t, false, testPk.GE(pk)) + // test LT + err = testPk.SetValue(float32(200)) + assert.NoError(t, err) + assert.Equal(t, true, pk.LT(testPk)) + assert.Equal(t, false, testPk.LT(pk)) + assert.Equal(t, true, pk.LE(testPk)) + assert.Equal(t, false, testPk.LE(pk)) + + t.Run("unmarshal", func(t *testing.T) { + blob, err := json.Marshal(pk) + assert.NoError(t, err) + + unmarshalledPk := &FloatFieldValue{} + err = json.Unmarshal(blob, unmarshalledPk) + assert.NoError(t, err) + assert.Equal(t, pk.Value, unmarshalledPk.Value) + }) +} + +func TestDoubleFieldValue(t *testing.T) { + pk := NewDoubleFieldValue(100) + + testPk := NewDoubleFieldValue(100) + // test GE + assert.Equal(t, true, pk.GE(testPk)) + assert.Equal(t, true, testPk.GE(pk)) + // test LE + assert.Equal(t, true, pk.LE(testPk)) + assert.Equal(t, true, testPk.LE(pk)) + // test EQ + assert.Equal(t, true, pk.EQ(testPk)) + // test GT + err := 
testPk.SetValue(float64(10)) + assert.NoError(t, err) + assert.Equal(t, true, pk.GT(testPk)) + assert.Equal(t, false, testPk.GT(pk)) + assert.Equal(t, true, pk.GE(testPk)) + assert.Equal(t, false, testPk.GE(pk)) + // test LT + err = testPk.SetValue(float64(200)) + assert.NoError(t, err) + assert.Equal(t, true, pk.LT(testPk)) + assert.Equal(t, false, testPk.LT(pk)) + assert.Equal(t, true, pk.LE(testPk)) + assert.Equal(t, false, testPk.LE(pk)) + + t.Run("unmarshal", func(t *testing.T) { + blob, err := json.Marshal(pk) + assert.NoError(t, err) + + unmarshalledPk := &DoubleFieldValue{} + err = json.Unmarshal(blob, unmarshalledPk) + assert.NoError(t, err) + assert.Equal(t, pk.Value, unmarshalledPk.Value) + }) +} + +func TestFieldValueSize(t *testing.T) { + vcf := NewVarCharFieldValue("milvus") + assert.Equal(t, int64(56), vcf.Size()) + + stf := NewStringFieldValue("milvus") + assert.Equal(t, int64(56), stf.Size()) + + int8f := NewInt8FieldValue(100) + assert.Equal(t, int64(2), int8f.Size()) + + int16f := NewInt16FieldValue(100) + assert.Equal(t, int64(4), int16f.Size()) + + int32f := NewInt32FieldValue(100) + assert.Equal(t, int64(8), int32f.Size()) + + int64f := NewInt64FieldValue(100) + assert.Equal(t, int64(16), int64f.Size()) + + floatf := NewFloatFieldValue(float32(10.7)) + assert.Equal(t, int64(8), floatf.Size()) + + doublef := NewDoubleFieldValue(float64(10.7)) + assert.Equal(t, int64(16), doublef.Size()) +} + +func TestFloatVectorFieldValue(t *testing.T) { + pk := NewFloatVectorFieldValue([]float32{1.0, 2.0, 3.0, 4.0}) + + t.Run("unmarshal", func(t *testing.T) { + blob, err := json.Marshal(pk) + assert.NoError(t, err) + + unmarshalledPk := &FloatVectorFieldValue{} + err = json.Unmarshal(blob, unmarshalledPk) + assert.NoError(t, err) + assert.Equal(t, pk.Value, unmarshalledPk.Value) + }) +} diff --git a/internal/storage/partition_stats.go b/internal/storage/partition_stats.go new file mode 100644 index 0000000000..6f55675e1d --- /dev/null +++ 
b/internal/storage/partition_stats.go @@ -0,0 +1,71 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import "encoding/json" + +type SegmentStats struct { + FieldStats []FieldStats `json:"fieldStats"` +} + +type PartitionStatsSnapshot struct { + SegmentStats map[UniqueID]SegmentStats `json:"segmentStats"` + Version int64 +} + +func NewPartitionStatsSnapshot() *PartitionStatsSnapshot { + return &PartitionStatsSnapshot{ + SegmentStats: make(map[UniqueID]SegmentStats, 0), + } +} + +func (ps *PartitionStatsSnapshot) GetVersion() int64 { + return ps.Version +} + +func (ps *PartitionStatsSnapshot) SetVersion(v int64) { + ps.Version = v +} + +func (ps *PartitionStatsSnapshot) UpdateSegmentStats(segmentID UniqueID, segmentStats SegmentStats) { + ps.SegmentStats[segmentID] = segmentStats +} + +func DeserializePartitionsStatsSnapshot(data []byte) (*PartitionStatsSnapshot, error) { + var messageMap map[string]*json.RawMessage + err := json.Unmarshal(data, &messageMap) + if err != nil { + return nil, err + } + + partitionStats := &PartitionStatsSnapshot{ + SegmentStats: make(map[UniqueID]SegmentStats), + } + err = json.Unmarshal(*messageMap["segmentStats"], &partitionStats.SegmentStats) + if err != nil { + return nil, 
err + } + return partitionStats, nil +} + +func SerializePartitionStatsSnapshot(partStats *PartitionStatsSnapshot) ([]byte, error) { + partData, err := json.Marshal(partStats) + if err != nil { + return nil, err + } + return partData, nil +} diff --git a/internal/storage/partition_stats_test.go b/internal/storage/partition_stats_test.go new file mode 100644 index 0000000000..e7cd496836 --- /dev/null +++ b/internal/storage/partition_stats_test.go @@ -0,0 +1,77 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package storage + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func TestPartitionStats(t *testing.T) { + partStats := NewPartitionStatsSnapshot() + { + fieldStats := make([]FieldStats, 0) + fieldStat1 := FieldStats{ + FieldID: 1, + Type: schemapb.DataType_Int64, + Max: NewInt64FieldValue(200), + Min: NewInt64FieldValue(100), + } + fieldStat2 := FieldStats{ + FieldID: 2, + Type: schemapb.DataType_Int64, + Max: NewInt64FieldValue(200), + Min: NewInt64FieldValue(100), + } + fieldStats = append(fieldStats, fieldStat1) + fieldStats = append(fieldStats, fieldStat2) + + partStats.UpdateSegmentStats(1, SegmentStats{ + FieldStats: fieldStats, + }) + } + { + fieldStat1 := FieldStats{ + FieldID: 1, + Type: schemapb.DataType_Int64, + Max: NewInt64FieldValue(200), + Min: NewInt64FieldValue(100), + } + fieldStat2 := FieldStats{ + FieldID: 2, + Type: schemapb.DataType_Int64, + Max: NewInt64FieldValue(200), + Min: NewInt64FieldValue(100), + } + partStats.UpdateSegmentStats(1, SegmentStats{ + FieldStats: []FieldStats{fieldStat1, fieldStat2}, + }) + } + partStats.SetVersion(100) + assert.Equal(t, int64(100), partStats.GetVersion()) + partBytes, err := SerializePartitionStatsSnapshot(partStats) + assert.NoError(t, err) + assert.NotNil(t, partBytes) + desPartStats, err := DeserializePartitionsStatsSnapshot(partBytes) + assert.NoError(t, err) + assert.NotNil(t, desPartStats) + assert.Equal(t, 1, len(desPartStats.SegmentStats)) + assert.Equal(t, 2, len(desPartStats.SegmentStats[1].FieldStats)) +}