mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-28 14:35:27 +08:00
feat: Add nullable vector support for proxy and querynode (#46305)
related: #45993

This commit extends nullable vector support to the proxy layer and querynode, and adds comprehensive validation, search reduce, and field data handling for nullable vectors with sparse storage.

Proxy layer changes:
- Update validate_util.go checkAligned() with getExpectedVectorRows() helper to validate nullable vector field alignment using valid data count
- Update checkFloatVectorFieldData/checkSparseFloatVectorFieldData for nullable vector validation with proper row count expectations
- Add FieldDataIdxComputer in typeutil/schema.go for logical-to-physical index translation during search reduce operations
- Update search_reduce_util.go reduceSearchResultData to use idxComputers for correct field data indexing with nullable vectors
- Update task.go, task_query.go, task_upsert.go for nullable vector handling
- Update msg_pack.go with nullable vector field data processing

QueryNode layer changes:
- Update segments/result.go for nullable vector result handling
- Update segments/search_reduce.go with nullable vector offset translation

Storage and index changes:
- Update data_codec.go and utils.go for nullable vector serialization
- Update indexcgowrapper/dataset.go and index.go for nullable vector indexing

Utility changes:
- Add FieldDataIdxComputer struct with Compute() method for efficient logical-to-physical index mapping across multiple field data
- Update EstimateEntitySize() and AppendFieldData() with fieldIdxs parameter
- Update funcutil.go with nullable vector support functions

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

* **New Features**
  * Full support for nullable vector fields (float, binary, float16, bfloat16, int8, sparse) across ingest, storage, indexing, search and retrieval; logical↔physical offset mapping preserves row semantics.
  * Client: compaction control and compaction-state APIs.
* **Bug Fixes**
  * Improved validation for adding vector fields (nullable + dimension checks) and corrected search/query behavior for nullable vectors.
* **Chores**
  * Persisted validity maps with indexes and on-disk formats.
* **Tests**
  * Extensive new and updated end-to-end nullable-vector tests.

<sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: marcelo-cjl <marcelo.chen@zilliz.com>
This commit is contained in:
parent
e4b0f48bc0
commit
3b599441fd
@ -46,6 +46,7 @@ type Column interface {
|
||||
SetNullable(bool)
|
||||
ValidateNullable() error
|
||||
CompactNullableValues()
|
||||
ValidCount() int
|
||||
}
|
||||
|
||||
var errFieldDataTypeNotMatch = errors.New("FieldData type not matched")
|
||||
@ -239,10 +240,39 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) {
|
||||
}
|
||||
data := x.FloatVector.GetData()
|
||||
dim := int(vectors.GetDim())
|
||||
|
||||
if len(validData) > 0 {
|
||||
if end < 0 {
|
||||
end = len(validData)
|
||||
}
|
||||
vector := make([][]float32, 0, end-begin)
|
||||
dataIdx := 0
|
||||
for i := 0; i < begin; i++ {
|
||||
if validData[i] {
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
for i := begin; i < end; i++ {
|
||||
if validData[i] {
|
||||
v := make([]float32, dim)
|
||||
copy(v, data[dataIdx*dim:(dataIdx+1)*dim])
|
||||
vector = append(vector, v)
|
||||
dataIdx++
|
||||
} else {
|
||||
vector = append(vector, nil)
|
||||
}
|
||||
}
|
||||
col := NewColumnFloatVector(fd.GetFieldName(), dim, vector)
|
||||
col.withValidData(validData[begin:end])
|
||||
col.nullable = true
|
||||
col.sparseMode = true
|
||||
return col, nil
|
||||
}
|
||||
|
||||
if end < 0 {
|
||||
end = len(data) / dim
|
||||
}
|
||||
vector := make([][]float32, 0, end-begin) // shall not have remanunt
|
||||
vector := make([][]float32, 0, end-begin)
|
||||
for i := begin; i < end; i++ {
|
||||
v := make([]float32, dim)
|
||||
copy(v, data[i*dim:(i+1)*dim])
|
||||
@ -262,6 +292,35 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) {
|
||||
}
|
||||
dim := int(vectors.GetDim())
|
||||
blen := dim / 8
|
||||
|
||||
if len(validData) > 0 {
|
||||
if end < 0 {
|
||||
end = len(validData)
|
||||
}
|
||||
vector := make([][]byte, 0, end-begin)
|
||||
dataIdx := 0
|
||||
for i := 0; i < begin; i++ {
|
||||
if validData[i] {
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
for i := begin; i < end; i++ {
|
||||
if validData[i] {
|
||||
v := make([]byte, blen)
|
||||
copy(v, data[dataIdx*blen:(dataIdx+1)*blen])
|
||||
vector = append(vector, v)
|
||||
dataIdx++
|
||||
} else {
|
||||
vector = append(vector, nil)
|
||||
}
|
||||
}
|
||||
col := NewColumnBinaryVector(fd.GetFieldName(), dim, vector)
|
||||
col.withValidData(validData[begin:end])
|
||||
col.nullable = true
|
||||
col.sparseMode = true
|
||||
return col, nil
|
||||
}
|
||||
|
||||
if end < 0 {
|
||||
end = len(data) / blen
|
||||
}
|
||||
@ -281,13 +340,43 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) {
|
||||
}
|
||||
data := x.Float16Vector
|
||||
dim := int(vectors.GetDim())
|
||||
bytePerRow := dim * 2
|
||||
|
||||
if len(validData) > 0 {
|
||||
if end < 0 {
|
||||
end = len(validData)
|
||||
}
|
||||
vector := make([][]byte, 0, end-begin)
|
||||
dataIdx := 0
|
||||
for i := 0; i < begin; i++ {
|
||||
if validData[i] {
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
for i := begin; i < end; i++ {
|
||||
if validData[i] {
|
||||
v := make([]byte, bytePerRow)
|
||||
copy(v, data[dataIdx*bytePerRow:(dataIdx+1)*bytePerRow])
|
||||
vector = append(vector, v)
|
||||
dataIdx++
|
||||
} else {
|
||||
vector = append(vector, nil)
|
||||
}
|
||||
}
|
||||
col := NewColumnFloat16Vector(fd.GetFieldName(), dim, vector)
|
||||
col.withValidData(validData[begin:end])
|
||||
col.nullable = true
|
||||
col.sparseMode = true
|
||||
return col, nil
|
||||
}
|
||||
|
||||
if end < 0 {
|
||||
end = len(data) / dim / 2
|
||||
end = len(data) / bytePerRow
|
||||
}
|
||||
vector := make([][]byte, 0, end-begin)
|
||||
for i := begin; i < end; i++ {
|
||||
v := make([]byte, dim*2)
|
||||
copy(v, data[i*dim*2:(i+1)*dim*2])
|
||||
v := make([]byte, bytePerRow)
|
||||
copy(v, data[i*bytePerRow:(i+1)*bytePerRow])
|
||||
vector = append(vector, v)
|
||||
}
|
||||
return NewColumnFloat16Vector(fd.GetFieldName(), dim, vector), nil
|
||||
@ -300,13 +389,43 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) {
|
||||
}
|
||||
data := x.Bfloat16Vector
|
||||
dim := int(vectors.GetDim())
|
||||
if end < 0 {
|
||||
end = len(data) / dim / 2
|
||||
bytePerRow := dim * 2
|
||||
|
||||
if len(validData) > 0 {
|
||||
if end < 0 {
|
||||
end = len(validData)
|
||||
}
|
||||
vector := make([][]byte, 0, end-begin)
|
||||
dataIdx := 0
|
||||
for i := 0; i < begin; i++ {
|
||||
if validData[i] {
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
for i := begin; i < end; i++ {
|
||||
if validData[i] {
|
||||
v := make([]byte, bytePerRow)
|
||||
copy(v, data[dataIdx*bytePerRow:(dataIdx+1)*bytePerRow])
|
||||
vector = append(vector, v)
|
||||
dataIdx++
|
||||
} else {
|
||||
vector = append(vector, nil)
|
||||
}
|
||||
}
|
||||
col := NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector)
|
||||
col.withValidData(validData[begin:end])
|
||||
col.nullable = true
|
||||
col.sparseMode = true
|
||||
return col, nil
|
||||
}
|
||||
vector := make([][]byte, 0, end-begin) // shall not have remanunt
|
||||
|
||||
if end < 0 {
|
||||
end = len(data) / bytePerRow
|
||||
}
|
||||
vector := make([][]byte, 0, end-begin)
|
||||
for i := begin; i < end; i++ {
|
||||
v := make([]byte, dim*2)
|
||||
copy(v, data[i*dim*2:(i+1)*dim*2])
|
||||
v := make([]byte, bytePerRow)
|
||||
copy(v, data[i*bytePerRow:(i+1)*bytePerRow])
|
||||
vector = append(vector, v)
|
||||
}
|
||||
return NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector), nil
|
||||
@ -317,6 +436,37 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) {
|
||||
return nil, errFieldDataTypeNotMatch
|
||||
}
|
||||
data := sparseVectors.Contents
|
||||
|
||||
if len(validData) > 0 {
|
||||
if end < 0 {
|
||||
end = len(validData)
|
||||
}
|
||||
vectors := make([]entity.SparseEmbedding, 0, end-begin)
|
||||
dataIdx := 0
|
||||
for i := 0; i < begin; i++ {
|
||||
if validData[i] {
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
for i := begin; i < end; i++ {
|
||||
if validData[i] {
|
||||
vector, err := entity.DeserializeSliceSparseEmbedding(data[dataIdx])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vectors = append(vectors, vector)
|
||||
dataIdx++
|
||||
} else {
|
||||
vectors = append(vectors, nil)
|
||||
}
|
||||
}
|
||||
col := NewColumnSparseVectors(fd.GetFieldName(), vectors)
|
||||
col.withValidData(validData[begin:end])
|
||||
col.nullable = true
|
||||
col.sparseMode = true
|
||||
return col, nil
|
||||
}
|
||||
|
||||
if end < 0 {
|
||||
end = len(data)
|
||||
}
|
||||
@ -339,11 +489,41 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) {
|
||||
}
|
||||
data := x.Int8Vector
|
||||
dim := int(vectors.GetDim())
|
||||
|
||||
if len(validData) > 0 {
|
||||
if end < 0 {
|
||||
end = len(validData)
|
||||
}
|
||||
vector := make([][]int8, 0, end-begin)
|
||||
dataIdx := 0
|
||||
for i := 0; i < begin; i++ {
|
||||
if validData[i] {
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
for i := begin; i < end; i++ {
|
||||
if validData[i] {
|
||||
v := make([]int8, dim)
|
||||
for j := 0; j < dim; j++ {
|
||||
v[j] = int8(data[dataIdx*dim+j])
|
||||
}
|
||||
vector = append(vector, v)
|
||||
dataIdx++
|
||||
} else {
|
||||
vector = append(vector, nil)
|
||||
}
|
||||
}
|
||||
col := NewColumnInt8Vector(fd.GetFieldName(), dim, vector)
|
||||
col.withValidData(validData[begin:end])
|
||||
col.nullable = true
|
||||
col.sparseMode = true
|
||||
return col, nil
|
||||
}
|
||||
|
||||
if end < 0 {
|
||||
end = len(data) / dim
|
||||
}
|
||||
vector := make([][]int8, 0, end-begin) // shall not have remanunt
|
||||
// TODO caiyd: has better way to convert []byte to []int8 ?
|
||||
vector := make([][]int8, 0, end-begin)
|
||||
for i := begin; i < end; i++ {
|
||||
v := make([]int8, dim)
|
||||
for j := 0; j < dim; j++ {
|
||||
|
||||
@ -301,6 +301,19 @@ func (c *genericColumnBase[T]) CompactNullableValues() {
|
||||
c.values = c.values[0:cnt]
|
||||
}
|
||||
|
||||
func (c *genericColumnBase[T]) ValidCount() int {
|
||||
if !c.nullable || len(c.validData) == 0 {
|
||||
return len(c.values)
|
||||
}
|
||||
count := 0
|
||||
for _, v := range c.validData {
|
||||
if v {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func (c *genericColumnBase[T]) withValidData(validData []bool) {
|
||||
if len(validData) > 0 {
|
||||
c.nullable = true
|
||||
|
||||
@ -16,6 +16,12 @@
|
||||
|
||||
package column
|
||||
|
||||
import (
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
)
|
||||
|
||||
var (
|
||||
// scalars
|
||||
NewNullableColumnBool NullableColumnCreateFunc[bool, *ColumnBool] = NewNullableColumnCreator(NewColumnBool).New
|
||||
@ -41,6 +47,76 @@ var (
|
||||
NewNullableColumnDoubleArray NullableColumnCreateFunc[[]float64, *ColumnDoubleArray] = NewNullableColumnCreator(NewColumnDoubleArray).New
|
||||
)
|
||||
|
||||
func NewNullableColumnFloatVector(fieldName string, dim int, values [][]float32, validData []bool) (*ColumnFloatVector, error) {
|
||||
if len(values) != getValidCount(validData) {
|
||||
return nil, errors.Newf("values length (%d) must equal valid count (%d) in validData", len(values), getValidCount(validData))
|
||||
}
|
||||
col := NewColumnFloatVector(fieldName, dim, values)
|
||||
col.withValidData(validData)
|
||||
col.nullable = true
|
||||
return col, nil
|
||||
}
|
||||
|
||||
func NewNullableColumnBinaryVector(fieldName string, dim int, values [][]byte, validData []bool) (*ColumnBinaryVector, error) {
|
||||
if len(values) != getValidCount(validData) {
|
||||
return nil, errors.Newf("values length (%d) must equal valid count (%d) in validData", len(values), getValidCount(validData))
|
||||
}
|
||||
col := NewColumnBinaryVector(fieldName, dim, values)
|
||||
col.withValidData(validData)
|
||||
col.nullable = true
|
||||
return col, nil
|
||||
}
|
||||
|
||||
func NewNullableColumnFloat16Vector(fieldName string, dim int, values [][]byte, validData []bool) (*ColumnFloat16Vector, error) {
|
||||
if len(values) != getValidCount(validData) {
|
||||
return nil, errors.Newf("values length (%d) must equal valid count (%d) in validData", len(values), getValidCount(validData))
|
||||
}
|
||||
col := NewColumnFloat16Vector(fieldName, dim, values)
|
||||
col.withValidData(validData)
|
||||
col.nullable = true
|
||||
return col, nil
|
||||
}
|
||||
|
||||
func NewNullableColumnBFloat16Vector(fieldName string, dim int, values [][]byte, validData []bool) (*ColumnBFloat16Vector, error) {
|
||||
if len(values) != getValidCount(validData) {
|
||||
return nil, errors.Newf("values length (%d) must equal valid count (%d) in validData", len(values), getValidCount(validData))
|
||||
}
|
||||
col := NewColumnBFloat16Vector(fieldName, dim, values)
|
||||
col.withValidData(validData)
|
||||
col.nullable = true
|
||||
return col, nil
|
||||
}
|
||||
|
||||
func NewNullableColumnInt8Vector(fieldName string, dim int, values [][]int8, validData []bool) (*ColumnInt8Vector, error) {
|
||||
if len(values) != getValidCount(validData) {
|
||||
return nil, errors.Newf("values length (%d) must equal valid count (%d) in validData", len(values), getValidCount(validData))
|
||||
}
|
||||
col := NewColumnInt8Vector(fieldName, dim, values)
|
||||
col.withValidData(validData)
|
||||
col.nullable = true
|
||||
return col, nil
|
||||
}
|
||||
|
||||
func NewNullableColumnSparseFloatVector(fieldName string, values []entity.SparseEmbedding, validData []bool) (*ColumnSparseFloatVector, error) {
|
||||
if len(values) != getValidCount(validData) {
|
||||
return nil, errors.Newf("values length (%d) must equal valid count (%d) in validData", len(values), getValidCount(validData))
|
||||
}
|
||||
col := NewColumnSparseVectors(fieldName, values)
|
||||
col.withValidData(validData)
|
||||
col.nullable = true
|
||||
return col, nil
|
||||
}
|
||||
|
||||
// getValidCount returns the number of true entries in validData.
// A nil or empty slice yields 0.
func getValidCount(validData []bool) int {
	total := 0
	for i := 0; i < len(validData); i++ {
		if validData[i] {
			total++
		}
	}
	return total
}
|
||||
|
||||
type NullableColumnCreateFunc[T any, Col interface {
|
||||
Column
|
||||
Data() []T
|
||||
|
||||
@ -38,11 +38,15 @@ func NewColumnSparseVectors(name string, values []entity.SparseEmbedding) *Colum
|
||||
|
||||
func (c *ColumnSparseFloatVector) FieldData() *schemapb.FieldData {
|
||||
fd := c.vectorBase.FieldData()
|
||||
max := lo.MaxBy(c.values, func(a, b entity.SparseEmbedding) bool {
|
||||
return a.Dim() > b.Dim()
|
||||
})
|
||||
vectors := fd.GetVectors()
|
||||
vectors.Dim = int64(max.Dim())
|
||||
if len(c.values) > 0 {
|
||||
max := lo.MaxBy(c.values, func(a, b entity.SparseEmbedding) bool {
|
||||
return a.Dim() > b.Dim()
|
||||
})
|
||||
vectors.Dim = int64(max.Dim())
|
||||
} else {
|
||||
vectors.Dim = 0
|
||||
}
|
||||
return fd
|
||||
}
|
||||
|
||||
|
||||
@ -136,3 +136,7 @@ func (c *columnStructArray) CompactNullableValues() {
|
||||
field.CompactNullableValues()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *columnStructArray) ValidCount() int {
|
||||
return c.Len()
|
||||
}
|
||||
|
||||
@ -206,6 +206,17 @@ const (
|
||||
FieldTypeStruct FieldType = 201
|
||||
)
|
||||
|
||||
// IsVectorType returns true if the field type is a vector type
|
||||
func (t FieldType) IsVectorType() bool {
|
||||
switch t {
|
||||
case FieldTypeBinaryVector, FieldTypeFloatVector, FieldTypeFloat16Vector,
|
||||
FieldTypeBFloat16Vector, FieldTypeSparseVector, FieldTypeInt8Vector:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Field represent field schema in milvus
|
||||
type Field struct {
|
||||
ID int64 // field id, generated when collection is created, input value is ignored
|
||||
|
||||
@ -185,6 +185,10 @@ func (c *Client) GetCollectionStats(ctx context.Context, opt GetCollectionOption
|
||||
|
||||
// AddCollectionField adds a field to a collection.
|
||||
func (c *Client) AddCollectionField(ctx context.Context, opt AddCollectionFieldOption, callOpts ...grpc.CallOption) error {
|
||||
if err := opt.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req := opt.Request()
|
||||
|
||||
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {
|
||||
|
||||
@ -405,6 +405,8 @@ func NewGetCollectionStatsOption(collectionName string) *getCollectionStatsOptio
|
||||
|
||||
type AddCollectionFieldOption interface {
|
||||
Request() *milvuspb.AddCollectionFieldRequest
|
||||
// Validate validates the option before sending request
|
||||
Validate() error
|
||||
}
|
||||
|
||||
type addCollectionFieldOption struct {
|
||||
@ -420,6 +422,15 @@ func (c *addCollectionFieldOption) Request() *milvuspb.AddCollectionFieldRequest
|
||||
}
|
||||
}
|
||||
|
||||
// Validate validates the option before sending request
|
||||
func (c *addCollectionFieldOption) Validate() error {
|
||||
// Vector fields must be nullable when adding to existing collection
|
||||
if c.fieldSch.DataType.IsVectorType() && !c.fieldSch.Nullable {
|
||||
return fmt.Errorf("adding vector field to existing collection requires nullable=true, field name = %s", c.fieldSch.Name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewAddCollectionFieldOption(collectionName string, field *entity.Field) *addCollectionFieldOption {
|
||||
return &addCollectionFieldOption{
|
||||
collectionName: collectionName,
|
||||
|
||||
@ -441,6 +441,37 @@ func (s *CollectionSuite) TestAddCollectionField() {
|
||||
err := s.client.AddCollectionField(ctx, NewAddCollectionFieldOption(collName, field))
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
s.Run("vector_field_without_nullable", func() {
|
||||
collName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
fieldName := fmt.Sprintf("field_%s", s.randString(6))
|
||||
// no mock expected because validation should fail before RPC call
|
||||
|
||||
field := entity.NewField().WithName(fieldName).WithDataType(entity.FieldTypeFloatVector).WithDim(128)
|
||||
|
||||
err := s.client.AddCollectionField(ctx, NewAddCollectionFieldOption(collName, field))
|
||||
s.Error(err)
|
||||
s.Contains(err.Error(), "adding vector field to existing collection requires nullable=true")
|
||||
})
|
||||
|
||||
s.Run("vector_field_with_nullable", func() {
|
||||
collName := fmt.Sprintf("coll_%s", s.randString(6))
|
||||
fieldName := fmt.Sprintf("field_%s", s.randString(6))
|
||||
s.mock.EXPECT().AddCollectionField(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, acfr *milvuspb.AddCollectionFieldRequest) (*commonpb.Status, error) {
|
||||
fieldProto := &schemapb.FieldSchema{}
|
||||
err := proto.Unmarshal(acfr.GetSchema(), fieldProto)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(fieldName, fieldProto.GetName())
|
||||
s.Equal(schemapb.DataType_FloatVector, fieldProto.GetDataType())
|
||||
s.True(fieldProto.GetNullable())
|
||||
return merr.Success(), nil
|
||||
}).Once()
|
||||
|
||||
field := entity.NewField().WithName(fieldName).WithDataType(entity.FieldTypeFloatVector).WithDim(128).WithNullable(true)
|
||||
|
||||
err := s.client.AddCollectionField(ctx, NewAddCollectionFieldOption(collName, field))
|
||||
s.NoError(err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCollection(t *testing.T) {
|
||||
|
||||
@ -126,6 +126,11 @@ class Chunk {
|
||||
return data_;
|
||||
}
|
||||
|
||||
// Returns a mutable reference to this chunk's validity vector (valid_).
FixedVector<bool>&
Valid() {
    auto& validity = valid_;
    return validity;
}
|
||||
|
||||
virtual bool
|
||||
isValid(int offset) const {
|
||||
if (nullable_) {
|
||||
@ -559,17 +564,32 @@ class SparseFloatVectorChunk : public Chunk {
|
||||
bool nullable,
|
||||
std::shared_ptr<ChunkMmapGuard> chunk_mmap_guard)
|
||||
: Chunk(row_nums, data, size, nullable, chunk_mmap_guard) {
|
||||
vec_.resize(row_nums);
|
||||
auto null_bitmap_bytes_num = nullable ? (row_nums + 7) / 8 : 0;
|
||||
auto offsets_ptr =
|
||||
reinterpret_cast<uint64_t*>(data + null_bitmap_bytes_num);
|
||||
for (int i = 0; i < row_nums; i++) {
|
||||
vec_[i] = {(offsets_ptr[i + 1] - offsets_ptr[i]) /
|
||||
knowhere::sparse::SparseRow<
|
||||
SparseValueType>::element_size(),
|
||||
reinterpret_cast<uint8_t*>(data + offsets_ptr[i]),
|
||||
false};
|
||||
dim_ = std::max(dim_, vec_[i].dim());
|
||||
|
||||
if (nullable_) {
|
||||
for (int i = 0; i < row_nums; i++) {
|
||||
if (isValid(i)) {
|
||||
vec_.emplace_back(
|
||||
(offsets_ptr[i + 1] - offsets_ptr[i]) /
|
||||
knowhere::sparse::SparseRow<
|
||||
SparseValueType>::element_size(),
|
||||
reinterpret_cast<uint8_t*>(data + offsets_ptr[i]),
|
||||
false);
|
||||
dim_ = std::max(dim_, vec_.back().dim());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
vec_.resize(row_nums);
|
||||
for (int i = 0; i < row_nums; i++) {
|
||||
vec_[i] = {(offsets_ptr[i + 1] - offsets_ptr[i]) /
|
||||
knowhere::sparse::SparseRow<
|
||||
SparseValueType>::element_size(),
|
||||
reinterpret_cast<uint8_t*>(data + offsets_ptr[i]),
|
||||
false};
|
||||
dim_ = std::max(dim_, vec_[i].dim());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -435,8 +435,10 @@ SparseFloatVectorChunkWriter::calculate_size(
|
||||
for (const auto& data : array_vec) {
|
||||
auto array = std::dynamic_pointer_cast<arrow::BinaryArray>(data);
|
||||
for (int64_t i = 0; i < array->length(); ++i) {
|
||||
auto str = array->GetView(i);
|
||||
size += str.size();
|
||||
if (!nullable_ || !array->IsNull(i)) {
|
||||
auto str = array->GetView(i);
|
||||
size += str.size();
|
||||
}
|
||||
}
|
||||
row_nums_ += array->length();
|
||||
}
|
||||
@ -459,8 +461,10 @@ SparseFloatVectorChunkWriter::write_to_target(
|
||||
for (const auto& data : array_vec) {
|
||||
auto array = std::dynamic_pointer_cast<arrow::BinaryArray>(data);
|
||||
for (int64_t i = 0; i < array->length(); ++i) {
|
||||
auto str = array->GetView(i);
|
||||
strs.emplace_back(str);
|
||||
if (!nullable_ || !array->IsNull(i)) {
|
||||
auto str = array->GetView(i);
|
||||
strs.emplace_back(str);
|
||||
}
|
||||
}
|
||||
if (nullable_) {
|
||||
null_bitmaps.emplace_back(
|
||||
@ -478,9 +482,23 @@ SparseFloatVectorChunkWriter::write_to_target(
|
||||
std::vector<uint64_t> offsets;
|
||||
offsets.reserve(offset_num);
|
||||
|
||||
for (const auto& str : strs) {
|
||||
offsets.push_back(offset_start_pos);
|
||||
offset_start_pos += str.size();
|
||||
if (nullable_) {
|
||||
size_t str_idx = 0;
|
||||
for (const auto& data : array_vec) {
|
||||
auto array = std::dynamic_pointer_cast<arrow::BinaryArray>(data);
|
||||
for (int i = 0; i < array->length(); i++) {
|
||||
offsets.push_back(offset_start_pos);
|
||||
if (!array->IsNull(i)) {
|
||||
offset_start_pos += strs[str_idx].size();
|
||||
str_idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (const auto& str : strs) {
|
||||
offsets.push_back(offset_start_pos);
|
||||
offset_start_pos += str.size();
|
||||
}
|
||||
}
|
||||
offsets.push_back(offset_start_pos);
|
||||
|
||||
@ -524,22 +542,43 @@ create_chunk_writer(const FieldMeta& field_meta) {
|
||||
return std::make_shared<ChunkWriter<arrow::Int64Array, int64_t>>(
|
||||
dim, nullable);
|
||||
case milvus::DataType::VECTOR_FLOAT:
|
||||
if (nullable) {
|
||||
return std::make_shared<
|
||||
NullableVectorChunkWriter<knowhere::fp32>>(dim, nullable);
|
||||
}
|
||||
return std::make_shared<
|
||||
ChunkWriter<arrow::FixedSizeBinaryArray, knowhere::fp32>>(
|
||||
dim, nullable);
|
||||
case milvus::DataType::VECTOR_BINARY:
|
||||
if (nullable) {
|
||||
return std::make_shared<
|
||||
NullableVectorChunkWriter<knowhere::bin1>>(dim / 8,
|
||||
nullable);
|
||||
}
|
||||
return std::make_shared<
|
||||
ChunkWriter<arrow::FixedSizeBinaryArray, knowhere::bin1>>(
|
||||
dim / 8, nullable);
|
||||
case milvus::DataType::VECTOR_FLOAT16:
|
||||
if (nullable) {
|
||||
return std::make_shared<
|
||||
NullableVectorChunkWriter<knowhere::fp16>>(dim, nullable);
|
||||
}
|
||||
return std::make_shared<
|
||||
ChunkWriter<arrow::FixedSizeBinaryArray, knowhere::fp16>>(
|
||||
dim, nullable);
|
||||
case milvus::DataType::VECTOR_BFLOAT16:
|
||||
if (nullable) {
|
||||
return std::make_shared<
|
||||
NullableVectorChunkWriter<knowhere::bf16>>(dim, nullable);
|
||||
}
|
||||
return std::make_shared<
|
||||
ChunkWriter<arrow::FixedSizeBinaryArray, knowhere::bf16>>(
|
||||
dim, nullable);
|
||||
case milvus::DataType::VECTOR_INT8:
|
||||
if (nullable) {
|
||||
return std::make_shared<
|
||||
NullableVectorChunkWriter<knowhere::int8>>(dim, nullable);
|
||||
}
|
||||
return std::make_shared<
|
||||
ChunkWriter<arrow::FixedSizeBinaryArray, knowhere::int8>>(
|
||||
dim, nullable);
|
||||
|
||||
@ -129,6 +129,57 @@ class ChunkWriter final : public ChunkWriterBase {
|
||||
const int64_t dim_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class NullableVectorChunkWriter final : public ChunkWriterBase {
|
||||
public:
|
||||
NullableVectorChunkWriter(int64_t dim, bool nullable)
|
||||
: ChunkWriterBase(nullable), dim_(dim) {
|
||||
Assert(nullable && "NullableVectorChunkWriter requires nullable=true");
|
||||
}
|
||||
|
||||
std::pair<size_t, size_t>
|
||||
calculate_size(const arrow::ArrayVector& array_vec) override {
|
||||
size_t size = 0;
|
||||
size_t row_nums = 0;
|
||||
|
||||
for (const auto& data : array_vec) {
|
||||
row_nums += data->length();
|
||||
auto binary_array =
|
||||
std::static_pointer_cast<arrow::BinaryArray>(data);
|
||||
int64_t valid_count = data->length() - binary_array->null_count();
|
||||
size += valid_count * dim_ * sizeof(T);
|
||||
}
|
||||
|
||||
// null bitmap size
|
||||
size += (row_nums + 7) / 8;
|
||||
row_nums_ = row_nums;
|
||||
return {size, row_nums};
|
||||
}
|
||||
|
||||
void
|
||||
write_to_target(const arrow::ArrayVector& array_vec,
|
||||
const std::shared_ptr<ChunkTarget>& target) override {
|
||||
std::vector<std::tuple<const uint8_t*, int64_t, int64_t>> null_bitmaps;
|
||||
for (const auto& data : array_vec) {
|
||||
null_bitmaps.emplace_back(
|
||||
data->null_bitmap_data(), data->length(), data->offset());
|
||||
}
|
||||
write_null_bit_maps(null_bitmaps, target);
|
||||
|
||||
for (const auto& data : array_vec) {
|
||||
auto binary_array =
|
||||
std::static_pointer_cast<arrow::BinaryArray>(data);
|
||||
auto data_offset = binary_array->value_offset(0);
|
||||
auto data_ptr = binary_array->value_data()->data() + data_offset;
|
||||
int64_t valid_count = data->length() - binary_array->null_count();
|
||||
target->write(data_ptr, valid_count * dim_ * sizeof(T));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const int64_t dim_;
|
||||
};
|
||||
|
||||
template <>
|
||||
inline void
|
||||
ChunkWriter<arrow::BooleanArray, bool>::write_to_target(
|
||||
|
||||
@ -20,6 +20,7 @@
|
||||
#include "arrow/array/array_binary.h"
|
||||
#include "arrow/chunked_array.h"
|
||||
#include "bitset/detail/element_wise.h"
|
||||
#include "bitset/detail/popcount.h"
|
||||
#include "common/Array.h"
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/Exception.h"
|
||||
@ -310,6 +311,18 @@ FieldDataImpl<Type, is_type_entire_row>::FillFieldData(
|
||||
case DataType::VECTOR_BFLOAT16:
|
||||
case DataType::VECTOR_INT8:
|
||||
case DataType::VECTOR_BINARY: {
|
||||
if (nullable_) {
|
||||
auto binary_array =
|
||||
std::dynamic_pointer_cast<arrow::BinaryArray>(array);
|
||||
AssertInfo(binary_array != nullptr,
|
||||
"nullable vector must use BinaryArray");
|
||||
auto data_offset = binary_array->value_offset(0);
|
||||
return FillFieldData(
|
||||
binary_array->value_data()->data() + data_offset,
|
||||
binary_array->null_bitmap_data(),
|
||||
binary_array->length(),
|
||||
binary_array->offset());
|
||||
}
|
||||
auto array_info =
|
||||
GetDataInfoFromArray<arrow::FixedSizeBinaryArray,
|
||||
arrow::Type::type::FIXED_SIZE_BINARY>(
|
||||
@ -321,6 +334,20 @@ FieldDataImpl<Type, is_type_entire_row>::FillFieldData(
|
||||
"inconsistent data type");
|
||||
auto arr = std::dynamic_pointer_cast<arrow::BinaryArray>(array);
|
||||
std::vector<knowhere::sparse::SparseRow<SparseValueType>> values;
|
||||
|
||||
if (nullable_) {
|
||||
for (int64_t i = 0; i < element_count; ++i) {
|
||||
if (arr->IsValid(i)) {
|
||||
auto view = arr->GetString(i);
|
||||
values.push_back(
|
||||
CopyAndWrapSparseRow(view.data(), view.size()));
|
||||
}
|
||||
}
|
||||
return FillFieldData(values.data(),
|
||||
arr->null_bitmap_data(),
|
||||
arr->length(),
|
||||
arr->offset());
|
||||
}
|
||||
for (size_t index = 0; index < element_count; ++index) {
|
||||
auto view = arr->GetString(index);
|
||||
values.push_back(
|
||||
@ -572,6 +599,96 @@ template class FieldDataImpl<knowhere::sparse::SparseRow<SparseValueType>,
|
||||
true>;
|
||||
template class FieldDataImpl<VectorArray, true>;
|
||||
|
||||
template <typename Type, bool is_type_entire_row>
|
||||
void
|
||||
FieldDataVectorImpl<Type, is_type_entire_row>::FillFieldData(
|
||||
const void* field_data,
|
||||
const uint8_t* valid_data,
|
||||
ssize_t total_element_count,
|
||||
ssize_t offset) {
|
||||
AssertInfo(this->nullable_, "requires nullable to be true");
|
||||
if (total_element_count == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t valid_count = 0;
|
||||
if (valid_data) {
|
||||
int64_t bit_start = offset;
|
||||
int64_t bit_end = offset + total_element_count;
|
||||
|
||||
// Handle head: unaligned bits before first full byte
|
||||
int64_t first_full_byte = (bit_start + 7) / 8;
|
||||
int64_t last_full_byte = bit_end / 8;
|
||||
|
||||
// Process unaligned head bits
|
||||
for (int64_t bit_idx = bit_start;
|
||||
bit_idx < std::min(first_full_byte * 8, bit_end);
|
||||
++bit_idx) {
|
||||
if ((valid_data[bit_idx >> 3] >> (bit_idx & 7)) & 1) {
|
||||
valid_count++;
|
||||
}
|
||||
}
|
||||
|
||||
// Process aligned full bytes with popcount
|
||||
for (int64_t byte_idx = first_full_byte; byte_idx < last_full_byte;
|
||||
++byte_idx) {
|
||||
valid_count += bitset::detail::PopCountHelper<uint8_t>::count(
|
||||
valid_data[byte_idx]);
|
||||
}
|
||||
|
||||
// Process unaligned tail bits
|
||||
for (int64_t bit_idx =
|
||||
std::max(last_full_byte * 8, first_full_byte * 8);
|
||||
bit_idx < bit_end;
|
||||
++bit_idx) {
|
||||
if ((valid_data[bit_idx >> 3] >> (bit_idx & 7)) & 1) {
|
||||
valid_count++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
valid_count = total_element_count;
|
||||
}
|
||||
|
||||
std::lock_guard lck(this->tell_mutex_);
|
||||
resize_field_data(this->length_ + total_element_count,
|
||||
this->valid_count_ + valid_count);
|
||||
|
||||
if (valid_data) {
|
||||
bitset::detail::ElementWiseBitsetPolicy<uint8_t>::op_copy(
|
||||
valid_data,
|
||||
offset,
|
||||
this->valid_data_.data(),
|
||||
this->length_,
|
||||
total_element_count);
|
||||
}
|
||||
|
||||
// update logical to physical offset mapping
|
||||
l2p_mapping_.build(this->valid_data_.data(),
|
||||
this->valid_count_,
|
||||
this->length_,
|
||||
total_element_count,
|
||||
valid_count);
|
||||
|
||||
if (valid_count > 0) {
|
||||
std::copy_n(static_cast<const Type*>(field_data),
|
||||
valid_count * this->dim_,
|
||||
this->data_.data() + this->valid_count_ * this->dim_);
|
||||
this->valid_count_ += valid_count;
|
||||
}
|
||||
|
||||
this->null_count_ = total_element_count - valid_count;
|
||||
this->length_ += total_element_count;
|
||||
}
|
||||
|
||||
// explicit instantiations for FieldDataVectorImpl
|
||||
template class FieldDataVectorImpl<uint8_t, false>;
|
||||
template class FieldDataVectorImpl<int8_t, false>;
|
||||
template class FieldDataVectorImpl<float, false>;
|
||||
template class FieldDataVectorImpl<float16, false>;
|
||||
template class FieldDataVectorImpl<bfloat16, false>;
|
||||
template class FieldDataVectorImpl<knowhere::sparse::SparseRow<SparseValueType>,
|
||||
true>;
|
||||
|
||||
FieldDataPtr
|
||||
InitScalarFieldData(const DataType& type, bool nullable, int64_t cap_rows) {
|
||||
switch (type) {
|
||||
|
||||
@ -19,6 +19,8 @@
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include <oneapi/tbb/concurrent_queue.h>
|
||||
|
||||
@ -133,24 +135,251 @@ class FieldData<VectorArray> : public FieldDataVectorArrayImpl {
|
||||
DataType element_type_;
|
||||
};
|
||||
|
||||
template <typename Type, bool is_type_entire_row = false>
|
||||
class FieldDataVectorImpl : public FieldDataImpl<Type, is_type_entire_row> {
|
||||
private:
|
||||
struct LogicalToPhysicalMapping {
|
||||
bool mapping{false};
|
||||
std::unordered_map<int64_t, int64_t> l2p_map;
|
||||
std::vector<int64_t> l2p_vec;
|
||||
|
||||
int64_t
|
||||
get_physical_offset(int64_t logical_offset) const {
|
||||
if (!mapping) {
|
||||
return logical_offset;
|
||||
}
|
||||
if (!l2p_map.empty()) {
|
||||
auto it = l2p_map.find(logical_offset);
|
||||
if (it != l2p_map.end()) {
|
||||
return it->second;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
if (logical_offset < static_cast<int64_t>(l2p_vec.size())) {
|
||||
return l2p_vec[logical_offset];
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
void
|
||||
build(const uint8_t* valid_data,
|
||||
int64_t start_physical,
|
||||
int64_t start_logical,
|
||||
int64_t total_count,
|
||||
int64_t valid_count) {
|
||||
if (total_count == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
mapping = true;
|
||||
|
||||
// use map when valid ratio < 10%
|
||||
bool use_map = (valid_count * 10 < total_count);
|
||||
|
||||
if (use_map) {
|
||||
int64_t physical_idx = start_physical;
|
||||
for (int64_t i = 0; i < total_count; ++i) {
|
||||
int64_t bit_pos = start_logical + i;
|
||||
if (valid_data == nullptr ||
|
||||
((valid_data[bit_pos >> 3] >> (bit_pos & 0x07)) & 1)) {
|
||||
l2p_map[start_logical + i] = physical_idx++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// resize l2p_vec if needed
|
||||
int64_t required_size = start_logical + total_count;
|
||||
if (static_cast<int64_t>(l2p_vec.size()) < required_size) {
|
||||
l2p_vec.resize(required_size, -1);
|
||||
}
|
||||
int64_t physical_idx = start_physical;
|
||||
for (int64_t i = 0; i < total_count; ++i) {
|
||||
int64_t bit_pos = start_logical + i;
|
||||
if (valid_data == nullptr ||
|
||||
((valid_data[bit_pos >> 3] >> (bit_pos & 0x07)) & 1)) {
|
||||
l2p_vec[start_logical + i] = physical_idx++;
|
||||
} else {
|
||||
l2p_vec[start_logical + i] = -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void
|
||||
resize_field_data(int64_t num_rows, int64_t valid_count) {
|
||||
Assert(this->nullable_);
|
||||
std::lock_guard lck(this->num_rows_mutex_);
|
||||
if (num_rows > this->num_rows_) {
|
||||
this->num_rows_ = num_rows;
|
||||
this->valid_data_.resize((num_rows + 7) / 8, 0x00);
|
||||
}
|
||||
if (valid_count > this->valid_count_) {
|
||||
this->data_.resize(valid_count * this->dim_);
|
||||
}
|
||||
}
|
||||
|
||||
LogicalToPhysicalMapping l2p_mapping_;
|
||||
|
||||
public:
|
||||
using FieldDataImpl<Type, is_type_entire_row>::FieldDataImpl;
|
||||
using FieldDataImpl<Type, is_type_entire_row>::resize_field_data;
|
||||
|
||||
void
|
||||
FillFieldData(const void* field_data,
|
||||
const uint8_t* valid_data,
|
||||
ssize_t element_count,
|
||||
ssize_t offset) override;
|
||||
|
||||
const void*
|
||||
RawValue(ssize_t offset) const override {
|
||||
auto physical_offset = l2p_mapping_.get_physical_offset(offset);
|
||||
if (physical_offset == -1) {
|
||||
return nullptr;
|
||||
}
|
||||
return &this->data_[physical_offset * this->dim_];
|
||||
}
|
||||
|
||||
int64_t
|
||||
DataSize() const override {
|
||||
auto dim = this->dim_;
|
||||
if (this->nullable_) {
|
||||
return sizeof(Type) * this->valid_count_ * dim;
|
||||
}
|
||||
return sizeof(Type) * this->length_ * dim;
|
||||
}
|
||||
|
||||
int64_t
|
||||
DataSize(ssize_t offset) const override {
|
||||
auto dim = this->dim_;
|
||||
AssertInfo(offset < this->get_num_rows(),
|
||||
"field data subscript out of range");
|
||||
return sizeof(Type) * dim;
|
||||
}
|
||||
|
||||
int64_t
|
||||
get_valid_rows() const override {
|
||||
if (this->nullable_) {
|
||||
return this->valid_count_;
|
||||
}
|
||||
return this->get_num_rows();
|
||||
}
|
||||
};
|
||||
|
||||
class FieldDataSparseVectorImpl
|
||||
: public FieldDataVectorImpl<knowhere::sparse::SparseRow<SparseValueType>,
|
||||
true> {
|
||||
using Base =
|
||||
FieldDataVectorImpl<knowhere::sparse::SparseRow<SparseValueType>, true>;
|
||||
|
||||
public:
|
||||
// Bring base class FillFieldData overloads into scope (for nullable support)
|
||||
using Base::FillFieldData;
|
||||
|
||||
explicit FieldDataSparseVectorImpl(DataType data_type,
|
||||
bool nullable = false,
|
||||
int64_t total_num_rows = 0)
|
||||
: FieldDataVectorImpl<knowhere::sparse::SparseRow<SparseValueType>,
|
||||
true>(
|
||||
/*dim=*/1, data_type, nullable, total_num_rows),
|
||||
vec_dim_(0) {
|
||||
AssertInfo(data_type == DataType::VECTOR_SPARSE_U32_F32,
|
||||
"invalid data type for sparse vector");
|
||||
}
|
||||
|
||||
int64_t
|
||||
DataSize() const override {
|
||||
int64_t data_size = 0;
|
||||
size_t count = nullable_ ? valid_count_ : length_;
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
data_size += data_[i].data_byte_size();
|
||||
}
|
||||
return data_size;
|
||||
}
|
||||
|
||||
int64_t
|
||||
DataSize(ssize_t offset) const override {
|
||||
AssertInfo(offset < get_num_rows(),
|
||||
"field data subscript out of range");
|
||||
size_t count = nullable_ ? valid_count_ : length_;
|
||||
AssertInfo(
|
||||
offset < count,
|
||||
"subscript position don't has valid value offset={}, count={}",
|
||||
offset,
|
||||
count);
|
||||
return data_[offset].data_byte_size();
|
||||
}
|
||||
|
||||
void
|
||||
FillFieldData(const void* source, ssize_t element_count) override {
|
||||
if (element_count == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::lock_guard lck(tell_mutex_);
|
||||
if (length_ + element_count > get_num_rows()) {
|
||||
FieldDataImpl::resize_field_data(length_ + element_count);
|
||||
}
|
||||
auto ptr =
|
||||
static_cast<const knowhere::sparse::SparseRow<SparseValueType>*>(
|
||||
source);
|
||||
for (int64_t i = 0; i < element_count; ++i) {
|
||||
auto& row = ptr[i];
|
||||
vec_dim_ = std::max(vec_dim_, row.dim());
|
||||
}
|
||||
std::copy_n(ptr, element_count, data_.data() + length_);
|
||||
length_ += element_count;
|
||||
}
|
||||
|
||||
void
|
||||
FillFieldData(const std::shared_ptr<arrow::BinaryArray>& array) override {
|
||||
auto n = array->length();
|
||||
if (n == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::lock_guard lck(tell_mutex_);
|
||||
if (length_ + n > get_num_rows()) {
|
||||
FieldDataImpl::resize_field_data(length_ + n);
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < array->length(); ++i) {
|
||||
auto view = array->GetView(i);
|
||||
auto& row = data_[length_ + i];
|
||||
row = CopyAndWrapSparseRow(view.data(), view.size());
|
||||
vec_dim_ = std::max(vec_dim_, row.dim());
|
||||
}
|
||||
length_ += n;
|
||||
}
|
||||
|
||||
int64_t
|
||||
Dim() const {
|
||||
return vec_dim_;
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t vec_dim_ = 0;
|
||||
};
|
||||
|
||||
template <>
|
||||
class FieldData<FloatVector> : public FieldDataImpl<float, false> {
|
||||
class FieldData<FloatVector> : public FieldDataVectorImpl<float, false> {
|
||||
public:
|
||||
explicit FieldData(int64_t dim,
|
||||
DataType data_type,
|
||||
bool nullable,
|
||||
int64_t buffered_num_rows = 0)
|
||||
: FieldDataImpl<float, false>::FieldDataImpl(
|
||||
dim, data_type, false, buffered_num_rows) {
|
||||
: FieldDataVectorImpl<float, false>::FieldDataVectorImpl(
|
||||
dim, data_type, nullable, buffered_num_rows) {
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class FieldData<BinaryVector> : public FieldDataImpl<uint8_t, false> {
|
||||
class FieldData<BinaryVector> : public FieldDataVectorImpl<uint8_t, false> {
|
||||
public:
|
||||
explicit FieldData(int64_t dim,
|
||||
DataType data_type,
|
||||
bool nullable,
|
||||
int64_t buffered_num_rows = 0)
|
||||
: FieldDataImpl(dim / 8, data_type, false, buffered_num_rows),
|
||||
: FieldDataVectorImpl(dim / 8, data_type, nullable, buffered_num_rows),
|
||||
binary_dim_(dim) {
|
||||
Assert(dim % 8 == 0);
|
||||
}
|
||||
@ -165,43 +394,48 @@ class FieldData<BinaryVector> : public FieldDataImpl<uint8_t, false> {
|
||||
};
|
||||
|
||||
template <>
|
||||
class FieldData<Float16Vector> : public FieldDataImpl<float16, false> {
|
||||
class FieldData<Float16Vector> : public FieldDataVectorImpl<float16, false> {
|
||||
public:
|
||||
explicit FieldData(int64_t dim,
|
||||
DataType data_type,
|
||||
bool nullable,
|
||||
int64_t buffered_num_rows = 0)
|
||||
: FieldDataImpl<float16, false>::FieldDataImpl(
|
||||
dim, data_type, false, buffered_num_rows) {
|
||||
: FieldDataVectorImpl<float16, false>::FieldDataVectorImpl(
|
||||
dim, data_type, nullable, buffered_num_rows) {
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class FieldData<BFloat16Vector> : public FieldDataImpl<bfloat16, false> {
|
||||
class FieldData<BFloat16Vector> : public FieldDataVectorImpl<bfloat16, false> {
|
||||
public:
|
||||
explicit FieldData(int64_t dim,
|
||||
DataType data_type,
|
||||
bool nullable,
|
||||
int64_t buffered_num_rows = 0)
|
||||
: FieldDataImpl<bfloat16, false>::FieldDataImpl(
|
||||
dim, data_type, false, buffered_num_rows) {
|
||||
: FieldDataVectorImpl<bfloat16, false>::FieldDataVectorImpl(
|
||||
dim, data_type, nullable, buffered_num_rows) {
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class FieldData<SparseFloatVector> : public FieldDataSparseVectorImpl {
|
||||
public:
|
||||
explicit FieldData(DataType data_type, int64_t buffered_num_rows = 0)
|
||||
: FieldDataSparseVectorImpl(data_type, buffered_num_rows) {
|
||||
explicit FieldData(DataType data_type,
|
||||
bool nullable = false,
|
||||
int64_t buffered_num_rows = 0)
|
||||
: FieldDataSparseVectorImpl(data_type, nullable, buffered_num_rows) {
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class FieldData<Int8Vector> : public FieldDataImpl<int8, false> {
|
||||
class FieldData<Int8Vector> : public FieldDataVectorImpl<int8, false> {
|
||||
public:
|
||||
explicit FieldData(int64_t dim,
|
||||
DataType data_type,
|
||||
bool nullable,
|
||||
int64_t buffered_num_rows = 0)
|
||||
: FieldDataImpl<int8, false>::FieldDataImpl(
|
||||
dim, data_type, false, buffered_num_rows) {
|
||||
: FieldDataVectorImpl<int8, false>::FieldDataVectorImpl(
|
||||
dim, data_type, nullable, buffered_num_rows) {
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -127,6 +127,9 @@ class FieldDataBase {
|
||||
virtual bool
|
||||
is_valid(ssize_t offset) const = 0;
|
||||
|
||||
virtual int64_t
|
||||
get_valid_rows() const = 0;
|
||||
|
||||
protected:
|
||||
const DataType data_type_;
|
||||
const bool nullable_;
|
||||
@ -309,6 +312,12 @@ class FieldBitsetImpl : public FieldDataBase {
|
||||
"is_valid(ssize_t offset) not implemented for bitset");
|
||||
}
|
||||
|
||||
int64_t
|
||||
get_valid_rows() const override {
|
||||
ThrowInfo(NotImplemented,
|
||||
"get_valid_rows() not implemented for bitset");
|
||||
}
|
||||
|
||||
private:
|
||||
FixedVector<Type> data_{};
|
||||
// capacity that data_ can store
|
||||
@ -340,9 +349,6 @@ class FieldDataImpl : public FieldDataBase {
|
||||
dim_(is_type_entire_row ? 1 : dim) {
|
||||
data_.resize(num_rows_ * dim_);
|
||||
if (nullable) {
|
||||
if (IsVectorDataType(data_type)) {
|
||||
ThrowInfo(NotImplemented, "vector type not support null");
|
||||
}
|
||||
valid_data_.resize((num_rows_ + 7) / 8, 0xFF);
|
||||
}
|
||||
}
|
||||
@ -492,6 +498,12 @@ class FieldDataImpl : public FieldDataBase {
|
||||
return num_rows_;
|
||||
}
|
||||
|
||||
int64_t
|
||||
get_valid_rows() const override {
|
||||
std::shared_lock lck(tell_mutex_);
|
||||
return static_cast<int64_t>(length_) - null_count_;
|
||||
}
|
||||
|
||||
void
|
||||
resize_field_data(int64_t num_rows) {
|
||||
std::lock_guard lck(num_rows_mutex_);
|
||||
@ -540,13 +552,12 @@ class FieldDataImpl : public FieldDataBase {
|
||||
FixedVector<uint8_t> valid_data_{};
|
||||
// number of elements data_ can hold
|
||||
int64_t num_rows_;
|
||||
size_t valid_count_{0};
|
||||
mutable std::shared_mutex num_rows_mutex_;
|
||||
int64_t null_count_{0};
|
||||
// number of actual elements in data_
|
||||
size_t length_{};
|
||||
mutable std::shared_mutex tell_mutex_;
|
||||
|
||||
private:
|
||||
const ssize_t dim_;
|
||||
};
|
||||
|
||||
@ -803,90 +814,6 @@ class FieldDataJsonImpl : public FieldDataImpl<Json, true> {
|
||||
}
|
||||
};
|
||||
|
||||
class FieldDataSparseVectorImpl
|
||||
: public FieldDataImpl<knowhere::sparse::SparseRow<SparseValueType>, true> {
|
||||
public:
|
||||
explicit FieldDataSparseVectorImpl(DataType data_type,
|
||||
int64_t total_num_rows = 0)
|
||||
: FieldDataImpl<knowhere::sparse::SparseRow<SparseValueType>, true>(
|
||||
/*dim=*/1, data_type, false, total_num_rows),
|
||||
vec_dim_(0) {
|
||||
AssertInfo(data_type == DataType::VECTOR_SPARSE_U32_F32,
|
||||
"invalid data type for sparse vector");
|
||||
}
|
||||
|
||||
int64_t
|
||||
DataSize() const override {
|
||||
int64_t data_size = 0;
|
||||
for (size_t i = 0; i < length(); ++i) {
|
||||
data_size += data_[i].data_byte_size();
|
||||
}
|
||||
return data_size;
|
||||
}
|
||||
|
||||
int64_t
|
||||
DataSize(ssize_t offset) const override {
|
||||
AssertInfo(offset < get_num_rows(),
|
||||
"field data subscript out of range");
|
||||
AssertInfo(offset < length(),
|
||||
"subscript position don't has valid value");
|
||||
return data_[offset].data_byte_size();
|
||||
}
|
||||
|
||||
// source is a pointer to element_count of
|
||||
// knowhere::sparse::SparseRow<SparseValueType>
|
||||
void
|
||||
FillFieldData(const void* source, ssize_t element_count) override {
|
||||
if (element_count == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::lock_guard lck(tell_mutex_);
|
||||
if (length_ + element_count > get_num_rows()) {
|
||||
resize_field_data(length_ + element_count);
|
||||
}
|
||||
auto ptr =
|
||||
static_cast<const knowhere::sparse::SparseRow<SparseValueType>*>(
|
||||
source);
|
||||
for (int64_t i = 0; i < element_count; ++i) {
|
||||
auto& row = ptr[i];
|
||||
vec_dim_ = std::max(vec_dim_, row.dim());
|
||||
}
|
||||
std::copy_n(ptr, element_count, data_.data() + length_);
|
||||
length_ += element_count;
|
||||
}
|
||||
|
||||
// each binary in array is a knowhere::sparse::SparseRow<SparseValueType>
|
||||
void
|
||||
FillFieldData(const std::shared_ptr<arrow::BinaryArray>& array) override {
|
||||
auto n = array->length();
|
||||
if (n == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::lock_guard lck(tell_mutex_);
|
||||
if (length_ + n > get_num_rows()) {
|
||||
resize_field_data(length_ + n);
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < array->length(); ++i) {
|
||||
auto view = array->GetView(i);
|
||||
auto& row = data_[length_ + i];
|
||||
row = CopyAndWrapSparseRow(view.data(), view.size());
|
||||
vec_dim_ = std::max(vec_dim_, row.dim());
|
||||
}
|
||||
length_ += n;
|
||||
}
|
||||
|
||||
int64_t
|
||||
Dim() const {
|
||||
return vec_dim_;
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t vec_dim_ = 0;
|
||||
};
|
||||
|
||||
class FieldDataArrayImpl : public FieldDataImpl<Array, true> {
|
||||
public:
|
||||
explicit FieldDataArrayImpl(DataType data_type,
|
||||
|
||||
@ -168,6 +168,8 @@ FieldMeta::ParseFrom(const milvus::proto::schema::FieldSchema& schema_proto) {
|
||||
}
|
||||
|
||||
if (IsVectorDataType(data_type)) {
|
||||
AssertInfo(!default_value.has_value(),
|
||||
"vector fields do not support default values");
|
||||
auto type_map = RepeatedKeyValToMap(schema_proto.type_params());
|
||||
auto index_map = RepeatedKeyValToMap(schema_proto.index_params());
|
||||
|
||||
@ -183,12 +185,17 @@ FieldMeta::ParseFrom(const milvus::proto::schema::FieldSchema& schema_proto) {
|
||||
data_type,
|
||||
dim,
|
||||
std::nullopt,
|
||||
false,
|
||||
nullable,
|
||||
default_value};
|
||||
}
|
||||
auto metric_type = index_map.at("metric_type");
|
||||
return FieldMeta{
|
||||
name, field_id, data_type, dim, metric_type, false, default_value};
|
||||
return FieldMeta{name,
|
||||
field_id,
|
||||
data_type,
|
||||
dim,
|
||||
metric_type,
|
||||
nullable,
|
||||
default_value};
|
||||
}
|
||||
|
||||
if (IsStringDataType(data_type)) {
|
||||
|
||||
@ -125,7 +125,8 @@ class FieldMeta {
|
||||
vector_info_(VectorInfo{dim, std::move(metric_type)}),
|
||||
default_value_(std::move(default_value)) {
|
||||
Assert(IsVectorDataType(type_));
|
||||
Assert(!nullable);
|
||||
Assert(!default_value_.has_value() &&
|
||||
"vector fields do not support default values");
|
||||
}
|
||||
|
||||
// array of vector type
|
||||
|
||||
194
internal/core/src/common/OffsetMapping.cpp
Normal file
194
internal/core/src/common/OffsetMapping.cpp
Normal file
@ -0,0 +1,194 @@
|
||||
#include "common/OffsetMapping.h"
|
||||
|
||||
namespace milvus {
|
||||
|
||||
void
|
||||
OffsetMapping::Build(const bool* valid_data,
|
||||
int64_t total_count,
|
||||
int64_t start_logical,
|
||||
int64_t start_physical) {
|
||||
if (total_count == 0 || valid_data == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::unique_lock<std::shared_mutex> lck(mutex_);
|
||||
enabled_ = true;
|
||||
total_count_ = start_logical + total_count;
|
||||
|
||||
// Count valid elements first
|
||||
int64_t valid_count = 0;
|
||||
for (int64_t i = 0; i < total_count; ++i) {
|
||||
if (valid_data[i]) {
|
||||
valid_count++;
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-select storage mode: use map when valid ratio < 10%
|
||||
use_map_ = (valid_count * 10 < total_count);
|
||||
|
||||
if (use_map_) {
|
||||
// Map mode: only store valid entries
|
||||
int64_t physical_idx = start_physical;
|
||||
for (int64_t i = 0; i < total_count; ++i) {
|
||||
if (valid_data[i]) {
|
||||
l2p_map_[start_logical + i] = physical_idx;
|
||||
p2l_map_[physical_idx] = start_logical + i;
|
||||
physical_idx++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Vec mode: store all entries
|
||||
int64_t required_size = start_logical + total_count;
|
||||
if (static_cast<int64_t>(l2p_vec_.size()) < required_size) {
|
||||
l2p_vec_.resize(required_size, -1);
|
||||
}
|
||||
|
||||
int64_t physical_idx = start_physical;
|
||||
for (int64_t i = 0; i < total_count; ++i) {
|
||||
if (valid_data[i]) {
|
||||
l2p_vec_[start_logical + i] = physical_idx;
|
||||
if (physical_idx >= static_cast<int64_t>(p2l_vec_.size())) {
|
||||
p2l_vec_.resize(physical_idx + 1, -1);
|
||||
}
|
||||
p2l_vec_[physical_idx] = start_logical + i;
|
||||
physical_idx++;
|
||||
} else {
|
||||
l2p_vec_[start_logical + i] = -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
valid_count_ += valid_count;
|
||||
}
|
||||
|
||||
void
|
||||
OffsetMapping::BuildIncremental(const bool* valid_data,
|
||||
int64_t count,
|
||||
int64_t start_logical,
|
||||
int64_t start_physical) {
|
||||
if (count == 0 || valid_data == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::unique_lock<std::shared_mutex> lck(mutex_);
|
||||
enabled_ = true;
|
||||
total_count_ = start_logical + count;
|
||||
|
||||
// Incremental builds always use vec mode
|
||||
if (use_map_ && !l2p_map_.empty()) {
|
||||
// Convert from map to vec if needed
|
||||
int64_t max_logical = 0;
|
||||
for (const auto& [logical, physical] : l2p_map_) {
|
||||
if (logical > max_logical) {
|
||||
max_logical = logical;
|
||||
}
|
||||
}
|
||||
l2p_vec_.resize(max_logical + 1, -1);
|
||||
for (const auto& [logical, physical] : l2p_map_) {
|
||||
l2p_vec_[logical] = physical;
|
||||
}
|
||||
int64_t max_physical = 0;
|
||||
for (const auto& [physical, logical] : p2l_map_) {
|
||||
if (physical > max_physical) {
|
||||
max_physical = physical;
|
||||
}
|
||||
}
|
||||
p2l_vec_.resize(max_physical + 1, -1);
|
||||
for (const auto& [physical, logical] : p2l_map_) {
|
||||
p2l_vec_[physical] = logical;
|
||||
}
|
||||
l2p_map_.clear();
|
||||
p2l_map_.clear();
|
||||
use_map_ = false;
|
||||
}
|
||||
|
||||
// Resize l2p_vec if needed
|
||||
int64_t required_size = start_logical + count;
|
||||
if (static_cast<int64_t>(l2p_vec_.size()) < required_size) {
|
||||
l2p_vec_.resize(required_size, -1);
|
||||
}
|
||||
|
||||
int64_t physical_idx = start_physical;
|
||||
for (int64_t i = 0; i < count; ++i) {
|
||||
if (valid_data[i]) {
|
||||
l2p_vec_[start_logical + i] = physical_idx;
|
||||
if (physical_idx >= static_cast<int64_t>(p2l_vec_.size())) {
|
||||
p2l_vec_.resize(physical_idx + 1, -1);
|
||||
}
|
||||
p2l_vec_[physical_idx] = start_logical + i;
|
||||
physical_idx++;
|
||||
valid_count_++;
|
||||
} else {
|
||||
l2p_vec_[start_logical + i] = -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int64_t
|
||||
OffsetMapping::GetPhysicalOffset(int64_t logical_offset) const {
|
||||
std::shared_lock<std::shared_mutex> lck(mutex_);
|
||||
if (!enabled_) {
|
||||
return logical_offset;
|
||||
}
|
||||
if (use_map_) {
|
||||
auto it = l2p_map_.find(static_cast<int32_t>(logical_offset));
|
||||
if (it != l2p_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
if (logical_offset < static_cast<int64_t>(l2p_vec_.size())) {
|
||||
return l2p_vec_[logical_offset];
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int64_t
|
||||
OffsetMapping::GetLogicalOffset(int64_t physical_offset) const {
|
||||
std::shared_lock<std::shared_mutex> lck(mutex_);
|
||||
if (!enabled_) {
|
||||
return physical_offset;
|
||||
}
|
||||
if (use_map_) {
|
||||
auto it = p2l_map_.find(static_cast<int32_t>(physical_offset));
|
||||
if (it != p2l_map_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
if (physical_offset < static_cast<int64_t>(p2l_vec_.size())) {
|
||||
return p2l_vec_[physical_offset];
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
bool
|
||||
OffsetMapping::IsValid(int64_t logical_offset) const {
|
||||
return GetPhysicalOffset(logical_offset) >= 0;
|
||||
}
|
||||
|
||||
int64_t
|
||||
OffsetMapping::GetValidCount() const {
|
||||
std::shared_lock<std::shared_mutex> lck(mutex_);
|
||||
return valid_count_;
|
||||
}
|
||||
|
||||
bool
|
||||
OffsetMapping::IsEnabled() const {
|
||||
std::shared_lock<std::shared_mutex> lck(mutex_);
|
||||
return enabled_;
|
||||
}
|
||||
|
||||
int64_t
|
||||
OffsetMapping::GetNextPhysicalOffset() const {
|
||||
std::shared_lock<std::shared_mutex> lck(mutex_);
|
||||
return valid_count_;
|
||||
}
|
||||
|
||||
int64_t
|
||||
OffsetMapping::GetTotalCount() const {
|
||||
std::shared_lock<std::shared_mutex> lck(mutex_);
|
||||
return total_count_;
|
||||
}
|
||||
|
||||
} // namespace milvus
|
||||
96
internal/core/src/common/OffsetMapping.h
Normal file
96
internal/core/src/common/OffsetMapping.h
Normal file
@ -0,0 +1,96 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <mutex>
|
||||
#include <shared_mutex>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace milvus {
|
||||
|
||||
// Bidirectional offset mapping for nullable vector storage
|
||||
// Maps between logical offsets (with nulls) and physical offsets (only valid data)
|
||||
// Supports two storage modes:
|
||||
// - vec mode: uses vector for both L2P and P2L, efficient when valid ratio >= 10%
|
||||
// - map mode: uses unordered_map for both L2P and P2L, efficient when valid ratio < 10%
|
||||
// Thread-safe table translating between logical row offsets (nulls counted)
// and physical row offsets (valid rows only). Build calls take mutex_
// exclusively; the accessors that read state take it shared.
class OffsetMapping {
|
||||
public:
|
||||
OffsetMapping() = default;
|
||||
|
||||
// Build mapping from valid_data (bool array format)
|
||||
// The storage mode is auto-selected from the valid ratio (< 10% valid uses map).
// valid_data holds one bool per row of the batch; the call is a no-op when
// total_count is 0 or valid_data is null.
|
||||
void
|
||||
Build(const bool* valid_data,
|
||||
int64_t total_count,
|
||||
int64_t start_logical = 0,
|
||||
int64_t start_physical = 0);
|
||||
|
||||
// Build mapping incrementally (always uses vec mode for incremental builds)
// The call is a no-op when count is 0 or valid_data is null.
|
||||
void
|
||||
BuildIncremental(const bool* valid_data,
|
||||
int64_t count,
|
||||
int64_t start_logical,
|
||||
int64_t start_physical);
|
||||
|
||||
// Get physical offset from logical offset. Returns -1 if null.
// Returns the input unchanged while the mapping is not enabled.
|
||||
int64_t
|
||||
GetPhysicalOffset(int64_t logical_offset) const;
|
||||
|
||||
// Get logical offset from physical offset. Returns -1 if not found.
// Returns the input unchanged while the mapping is not enabled.
|
||||
int64_t
|
||||
GetLogicalOffset(int64_t physical_offset) const;
|
||||
|
||||
// Check if a logical offset is valid (not null)
|
||||
bool
|
||||
IsValid(int64_t logical_offset) const;
|
||||
|
||||
// Get count of valid (non-null) elements
|
||||
int64_t
|
||||
GetValidCount() const;
|
||||
|
||||
// Check if mapping is enabled
|
||||
bool
|
||||
IsEnabled() const;
|
||||
|
||||
// Get next physical offset (for incremental builds)
// Equal to the number of valid elements recorded so far.
|
||||
int64_t
|
||||
GetNextPhysicalOffset() const;
|
||||
|
||||
// Get total logical count (including nulls)
|
||||
int64_t
|
||||
GetTotalCount() const;
|
||||
|
||||
private:
|
||||
// Set by the first build call that receives data; while false the Get*Offset
// accessors act as the identity.
bool enabled_{false};
|
||||
bool use_map_{false}; // true: lookups use the maps (both directions), false: use the vectors
|
||||
|
||||
// Vec mode storage (uses int32_t to save memory)
|
||||
std::vector<int32_t> l2p_vec_; // logical -> physical, -1 means null
|
||||
std::vector<int32_t> p2l_vec_; // physical -> logical
|
||||
|
||||
// Map mode storage (for sparse valid data)
|
||||
std::unordered_map<int32_t, int32_t> l2p_map_; // logical -> physical
|
||||
std::unordered_map<int32_t, int32_t> p2l_map_; // physical -> logical
|
||||
|
||||
int64_t valid_count_{0};  // number of valid rows recorded by the build calls
|
||||
int64_t total_count_{0}; // total logical count (including nulls)
|
||||
mutable std::shared_mutex mutex_;  // guards every member above
|
||||
};
|
||||
|
||||
} // namespace milvus
|
||||
@ -29,6 +29,8 @@
|
||||
|
||||
#include "common/FieldMeta.h"
|
||||
#include "common/ArrayOffsets.h"
|
||||
#include "common/OffsetMapping.h"
|
||||
#include "query/Utils.h"
|
||||
#include "pb/schema.pb.h"
|
||||
#include "knowhere/index/index_node.h"
|
||||
|
||||
@ -156,9 +158,10 @@ class VectorIterator {
|
||||
class ChunkMergeIterator : public VectorIterator {
|
||||
public:
|
||||
ChunkMergeIterator(int chunk_count,
|
||||
const milvus::OffsetMapping& offset_mapping,
|
||||
const std::vector<int64_t>& total_rows_until_chunk = {},
|
||||
bool larger_is_closer = false)
|
||||
: total_rows_until_chunk_(total_rows_until_chunk),
|
||||
: offset_mapping_(&offset_mapping),
|
||||
larger_is_closer_(larger_is_closer),
|
||||
heap_(OffsetDisPairComparator(larger_is_closer)) {
|
||||
iterators_.reserve(chunk_count);
|
||||
@ -180,7 +183,11 @@ class ChunkMergeIterator : public VectorIterator {
|
||||
origin_pair, top->GetIteratorIdx());
|
||||
heap_.push(off_dis_pair);
|
||||
}
|
||||
return top->GetOffDis();
|
||||
auto result = top->GetOffDis();
|
||||
if (offset_mapping_ != nullptr) {
|
||||
result.first = offset_mapping_->GetLogicalOffset(result.first);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
@ -231,6 +238,7 @@ class ChunkMergeIterator : public VectorIterator {
|
||||
OffsetDisPairComparator>
|
||||
heap_;
|
||||
bool sealed = false;
|
||||
const milvus::OffsetMapping* offset_mapping_ = nullptr;
|
||||
std::vector<int64_t> total_rows_until_chunk_;
|
||||
bool larger_is_closer_ = false;
|
||||
//currently, ChunkMergeIterator is guaranteed to be used serially without concurrent problem, in the future
|
||||
@ -258,6 +266,7 @@ struct SearchResult {
|
||||
int chunk_count,
|
||||
const std::vector<int64_t>& total_rows_until_chunk,
|
||||
const std::vector<knowhere::IndexNode::IteratorPtr>& kw_iterators,
|
||||
const milvus::OffsetMapping& offset_mapping,
|
||||
bool larger_is_closer = false) {
|
||||
AssertInfo(kw_iterators.size() == nq * chunk_count,
|
||||
"kw_iterators count:{} is not equal to nq*chunk_count:{}, "
|
||||
@ -269,8 +278,11 @@ struct SearchResult {
|
||||
for (int i = 0, vec_iter_idx = 0; i < kw_iterators.size(); i++) {
|
||||
vec_iter_idx = vec_iter_idx % nq;
|
||||
if (vector_iterators.size() < nq) {
|
||||
auto chunk_merge_iter = std::make_shared<ChunkMergeIterator>(
|
||||
chunk_count, total_rows_until_chunk, larger_is_closer);
|
||||
auto chunk_merge_iter =
|
||||
std::make_shared<ChunkMergeIterator>(chunk_count,
|
||||
offset_mapping,
|
||||
total_rows_until_chunk,
|
||||
larger_is_closer);
|
||||
vector_iterators.emplace_back(chunk_merge_iter);
|
||||
}
|
||||
const auto& kw_iterator = kw_iterators[i];
|
||||
|
||||
@ -95,7 +95,8 @@ class Schema {
|
||||
AddDebugField(const std::string& name,
|
||||
DataType data_type,
|
||||
int64_t dim,
|
||||
std::optional<knowhere::MetricType> metric_type) {
|
||||
std::optional<knowhere::MetricType> metric_type,
|
||||
bool nullable = false) {
|
||||
auto field_id = FieldId(debug_id);
|
||||
debug_id++;
|
||||
auto field_meta = FieldMeta(FieldName(name),
|
||||
@ -103,7 +104,7 @@ class Schema {
|
||||
data_type,
|
||||
dim,
|
||||
metric_type,
|
||||
false,
|
||||
nullable,
|
||||
std::nullopt);
|
||||
this->AddField(std::move(field_meta));
|
||||
return field_id;
|
||||
@ -225,7 +226,7 @@ class Schema {
|
||||
std::optional<knowhere::MetricType> metric_type,
|
||||
bool nullable) {
|
||||
auto field_meta = FieldMeta(
|
||||
name, id, data_type, dim, metric_type, false, std::nullopt);
|
||||
name, id, data_type, dim, metric_type, nullable, std::nullopt);
|
||||
this->AddField(std::move(field_meta));
|
||||
}
|
||||
|
||||
|
||||
@ -304,7 +304,9 @@ CopyAndWrapSparseRow(const void* data,
|
||||
template <typename Iterable>
|
||||
std::unique_ptr<knowhere::sparse::SparseRow<SparseValueType>[]>
|
||||
SparseBytesToRows(const Iterable& rows, const bool validate = false) {
|
||||
AssertInfo(rows.size() > 0, "at least 1 sparse row should be provided");
|
||||
if (rows.size() == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
auto res = std::make_unique<knowhere::sparse::SparseRow<SparseValueType>[]>(
|
||||
rows.size());
|
||||
for (size_t i = 0; i < rows.size(); ++i) {
|
||||
|
||||
@ -57,7 +57,12 @@ PrepareVectorIteratorsFromIndex(const SearchInfo& search_info,
|
||||
bool larger_is_closer =
|
||||
PositivelyRelated(search_info.metric_type_);
|
||||
search_result.AssembleChunkVectorIterators(
|
||||
nq, 1, {0}, iterators_val.value(), larger_is_closer);
|
||||
nq,
|
||||
1,
|
||||
{0},
|
||||
iterators_val.value(),
|
||||
index.GetOffsetMapping(),
|
||||
larger_is_closer);
|
||||
} else {
|
||||
std::string operator_type = "";
|
||||
if (search_info.group_by_field_id_.has_value()) {
|
||||
|
||||
@ -120,6 +120,26 @@ VectorDiskAnnIndex<T>::Load(milvus::tracer::TraceContext ctx,
|
||||
"failed to Deserialize index, " + KnowhereStatusString(stat));
|
||||
span_load_engine->End();
|
||||
|
||||
auto local_chunk_manager =
|
||||
storage::LocalChunkManagerSingleton::GetInstance().GetChunkManager();
|
||||
auto local_index_path_prefix = file_manager_->GetLocalIndexObjectPrefix();
|
||||
|
||||
auto valid_data_path = local_index_path_prefix + "/" + VALID_DATA_KEY;
|
||||
if (local_chunk_manager->Exist(valid_data_path)) {
|
||||
size_t count;
|
||||
local_chunk_manager->Read(valid_data_path, 0, &count, sizeof(size_t));
|
||||
size_t byte_size = (count + 7) / 8;
|
||||
std::vector<uint8_t> valid_bitmap(byte_size);
|
||||
local_chunk_manager->Read(
|
||||
valid_data_path, sizeof(size_t), valid_bitmap.data(), byte_size);
|
||||
// Convert bitmap to bool array
|
||||
std::unique_ptr<bool[]> valid_data(new bool[count]);
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
valid_data[i] = (valid_bitmap[i / 8] >> (i % 8)) & 1;
|
||||
}
|
||||
BuildValidData(valid_data.get(), count);
|
||||
}
|
||||
|
||||
SetDim(index_.Dim());
|
||||
}
|
||||
|
||||
@ -298,6 +318,23 @@ VectorDiskAnnIndex<T>::BuildWithDataset(const DatasetPtr& dataset,
|
||||
if (stat != knowhere::Status::success)
|
||||
ThrowInfo(ErrorCode::IndexBuildError,
|
||||
"failed to build index, " + KnowhereStatusString(stat));
|
||||
|
||||
if (HasValidData()) {
|
||||
auto valid_data_path = local_index_path_prefix + "/" + VALID_DATA_KEY;
|
||||
size_t count = offset_mapping_.GetTotalCount();
|
||||
local_chunk_manager->Write(valid_data_path, 0, &count, sizeof(size_t));
|
||||
size_t byte_size = (count + 7) / 8;
|
||||
std::vector<uint8_t> packed_data(byte_size, 0);
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
if (offset_mapping_.IsValid(i)) {
|
||||
packed_data[i / 8] |= (1 << (i % 8));
|
||||
}
|
||||
}
|
||||
local_chunk_manager->Write(
|
||||
valid_data_path, sizeof(size_t), packed_data.data(), byte_size);
|
||||
file_manager_->AddFile(valid_data_path);
|
||||
}
|
||||
|
||||
local_chunk_manager->RemoveDir(
|
||||
storage::GetSegmentRawDataPathPrefix(local_chunk_manager, segment_id));
|
||||
|
||||
|
||||
@ -27,6 +27,7 @@
|
||||
#include "index/Index.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/BitsetView.h"
|
||||
#include "common/OffsetMapping.h"
|
||||
#include "common/QueryResult.h"
|
||||
#include "common/QueryInfo.h"
|
||||
#include "common/OpContext.h"
|
||||
@ -34,6 +35,10 @@
|
||||
|
||||
namespace milvus::index {
|
||||
|
||||
// valid data keys for nullable vector index serialization
|
||||
constexpr const char* VALID_DATA_KEY = "valid_data";
|
||||
constexpr const char* VALID_DATA_COUNT_KEY = "valid_data_count";
|
||||
|
||||
class VectorIndex : public IndexBase {
|
||||
public:
|
||||
explicit VectorIndex(const IndexType& index_type,
|
||||
@ -145,6 +150,56 @@ class VectorIndex : public IndexBase {
|
||||
return search_cfg;
|
||||
}
|
||||
|
||||
void
|
||||
UpdateValidData(const bool* valid_data, int64_t count) {
|
||||
offset_mapping_.BuildIncremental(
|
||||
valid_data,
|
||||
count,
|
||||
offset_mapping_.GetTotalCount(),
|
||||
offset_mapping_.GetNextPhysicalOffset());
|
||||
}
|
||||
|
||||
// Forwards `valid_data` (an array of `total_count` flags) to
// offset_mapping_.Build(), which sets up the mapping used by the
// logical/physical offset accessors of this class.
void
|
||||
BuildValidData(const bool* valid_data, int64_t total_count) {
|
||||
offset_mapping_.Build(valid_data, total_count);
|
||||
}
|
||||
|
||||
bool
|
||||
IsRowValid(int64_t logical_offset) const {
|
||||
if (!offset_mapping_.IsEnabled()) {
|
||||
return true;
|
||||
}
|
||||
return offset_mapping_.IsValid(logical_offset);
|
||||
}
|
||||
|
||||
// Returns true when the offset mapping is enabled, i.e. validity
// information has been attached to this index.
bool
|
||||
HasValidData() const {
|
||||
return offset_mapping_.IsEnabled();
|
||||
}
|
||||
|
||||
// Returns offset_mapping_.GetValidCount().
// NOTE(review): presumably the number of non-null rows -- confirm in
// OffsetMapping.
int64_t
|
||||
GetValidCount() const {
|
||||
return offset_mapping_.GetValidCount();
|
||||
}
|
||||
|
||||
// Translates a logical row offset into the physical offset reported by the
// offset mapping.
// NOTE(review): the result for an invalid (null) row is decided by
// OffsetMapping and is not visible here.
int64_t
|
||||
GetPhysicalOffset(int64_t logical_offset) const {
|
||||
return offset_mapping_.GetPhysicalOffset(logical_offset);
|
||||
}
|
||||
|
||||
// Translates a physical offset back into the logical row offset reported by
// the offset mapping (the inverse direction of GetPhysicalOffset).
int64_t
|
||||
GetLogicalOffset(int64_t physical_offset) const {
|
||||
return offset_mapping_.GetLogicalOffset(physical_offset);
|
||||
}
|
||||
|
||||
// Read-only access to the offset mapping owned by this index. The returned
// reference refers to a member, so it must not outlive the index object.
const milvus::OffsetMapping&
|
||||
GetOffsetMapping() const {
|
||||
return offset_mapping_;
|
||||
}
|
||||
|
||||
protected:
|
||||
milvus::OffsetMapping offset_mapping_;
|
||||
|
||||
private:
|
||||
MetricType metric_type_;
|
||||
int64_t dim_;
|
||||
|
||||
@ -146,6 +146,27 @@ VectorMemIndex<T>::Serialize(const Config& config) {
|
||||
ThrowInfo(ErrorCode::UnexpectedError,
|
||||
"failed to serialize index: {}",
|
||||
KnowhereStatusString(stat));
|
||||
|
||||
// Serialize valid_data from offset_mapping if enabled
|
||||
if (offset_mapping_.IsEnabled()) {
|
||||
auto total_count = offset_mapping_.GetTotalCount();
|
||||
|
||||
std::shared_ptr<uint8_t[]> count_buf(new uint8_t[sizeof(size_t)]);
|
||||
size_t count = static_cast<size_t>(total_count);
|
||||
std::memcpy(count_buf.get(), &count, sizeof(size_t));
|
||||
ret.Append(VALID_DATA_COUNT_KEY, count_buf, sizeof(size_t));
|
||||
|
||||
size_t byte_size = (count + 7) / 8;
|
||||
std::shared_ptr<uint8_t[]> data(new uint8_t[byte_size]);
|
||||
std::memset(data.get(), 0, byte_size);
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
if (offset_mapping_.IsValid(i)) {
|
||||
data[i / 8] |= (1 << (i % 8));
|
||||
}
|
||||
}
|
||||
ret.Append(VALID_DATA_KEY, data, byte_size);
|
||||
}
|
||||
|
||||
Disassemble(ret);
|
||||
|
||||
return ret;
|
||||
@ -160,6 +181,25 @@ VectorMemIndex<T>::LoadWithoutAssemble(const BinarySet& binary_set,
|
||||
ThrowInfo(ErrorCode::UnexpectedError,
|
||||
"failed to Deserialize index: {}",
|
||||
KnowhereStatusString(stat));
|
||||
|
||||
// Deserialize valid_data bitmap and rebuild offset_mapping
|
||||
if (binary_set.Contains(VALID_DATA_COUNT_KEY) &&
|
||||
binary_set.Contains(VALID_DATA_KEY)) {
|
||||
knowhere::BinaryPtr ptr;
|
||||
ptr = binary_set.GetByName(VALID_DATA_COUNT_KEY);
|
||||
size_t count;
|
||||
std::memcpy(&count, ptr->data.get(), sizeof(size_t));
|
||||
|
||||
ptr = binary_set.GetByName(VALID_DATA_KEY);
|
||||
// Convert bitmap to bool array
|
||||
std::unique_ptr<bool[]> valid_data(new bool[count]);
|
||||
auto bitmap = ptr->data.get();
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
valid_data[i] = (bitmap[i / 8] >> (i % 8)) & 1;
|
||||
}
|
||||
BuildValidData(valid_data.get(), count);
|
||||
}
|
||||
|
||||
SetDim(index_.Dim());
|
||||
}
|
||||
|
||||
@ -339,19 +379,48 @@ VectorMemIndex<T>::Build(const Config& config) {
|
||||
build_config.update(config);
|
||||
build_config.erase(INSERT_FILES_KEY);
|
||||
build_config.erase(VEC_OPT_FIELDS);
|
||||
if (!IndexIsSparse(GetIndexType())) {
|
||||
int64_t total_size = 0;
|
||||
int64_t total_num_rows = 0;
|
||||
int64_t dim = 0;
|
||||
for (auto data : field_datas) {
|
||||
total_size += data->Size();
|
||||
total_num_rows += data->get_num_rows();
|
||||
|
||||
bool nullable = false;
|
||||
int64_t total_valid_rows = 0;
|
||||
int64_t total_num_rows = 0;
|
||||
for (auto data : field_datas) {
|
||||
auto num_rows = data->get_num_rows();
|
||||
auto valid_rows = data->get_valid_rows();
|
||||
total_valid_rows += valid_rows;
|
||||
total_num_rows += num_rows;
|
||||
if (data->IsNullable()) {
|
||||
nullable = true;
|
||||
}
|
||||
}
|
||||
std::unique_ptr<bool[]> valid_data;
|
||||
if (nullable) {
|
||||
valid_data.reset(new bool[total_num_rows]);
|
||||
int64_t chunk_offset = 0;
|
||||
for (auto data : field_datas) {
|
||||
auto rows = data->get_num_rows();
|
||||
// Copy valid data from FieldData (bitmap format to bool array)
|
||||
auto src_bitmap = data->ValidData();
|
||||
for (int64_t i = 0; i < rows; ++i) {
|
||||
valid_data[chunk_offset + i] =
|
||||
(src_bitmap[i >> 3] >> (i & 7)) & 1;
|
||||
}
|
||||
chunk_offset += rows;
|
||||
}
|
||||
}
|
||||
|
||||
if (!IndexIsSparse(GetIndexType())) {
|
||||
int64_t dim = 0;
|
||||
int64_t total_size = 0;
|
||||
for (auto data : field_datas) {
|
||||
AssertInfo(dim == 0 || dim == data->get_dim(),
|
||||
"inconsistent dim value between field datas!");
|
||||
dim = data->get_dim();
|
||||
if (elem_type_ == DataType::NONE) {
|
||||
total_size += data->DataSize();
|
||||
} else {
|
||||
total_size += data->Size();
|
||||
}
|
||||
}
|
||||
|
||||
auto buf = std::shared_ptr<uint8_t[]>(new uint8_t[total_size]);
|
||||
|
||||
size_t lim_offset = 0;
|
||||
@ -362,8 +431,9 @@ VectorMemIndex<T>::Build(const Config& config) {
|
||||
if (elem_type_ == DataType::NONE) {
|
||||
// TODO: avoid copying
|
||||
for (auto data : field_datas) {
|
||||
std::memcpy(buf.get() + offset, data->Data(), data->Size());
|
||||
offset += data->Size();
|
||||
auto valid_size = data->DataSize();
|
||||
std::memcpy(buf.get() + offset, data->Data(), valid_size);
|
||||
offset += valid_size;
|
||||
data.reset();
|
||||
}
|
||||
} else {
|
||||
@ -396,12 +466,12 @@ VectorMemIndex<T>::Build(const Config& config) {
|
||||
data.reset();
|
||||
}
|
||||
|
||||
total_num_rows = lim_offset;
|
||||
total_valid_rows = lim_offset;
|
||||
}
|
||||
|
||||
field_datas.clear();
|
||||
|
||||
auto dataset = GenDataset(total_num_rows, dim, buf.get());
|
||||
auto dataset = GenDataset(total_valid_rows, dim, buf.get());
|
||||
if (!scalar_info.empty()) {
|
||||
dataset->Set(knowhere::meta::SCALAR_INFO, std::move(scalar_info));
|
||||
}
|
||||
@ -410,12 +480,13 @@ VectorMemIndex<T>::Build(const Config& config) {
|
||||
const_cast<const size_t*>(offsets.data()));
|
||||
}
|
||||
BuildWithDataset(dataset, build_config);
|
||||
if (nullable) {
|
||||
BuildValidData(valid_data.get(), total_num_rows);
|
||||
}
|
||||
} else {
|
||||
// sparse
|
||||
int64_t total_rows = 0;
|
||||
int64_t dim = 0;
|
||||
for (auto field_data : field_datas) {
|
||||
total_rows += field_data->Length();
|
||||
dim = std::max(
|
||||
dim,
|
||||
std::dynamic_pointer_cast<FieldData<SparseFloatVector>>(
|
||||
@ -423,28 +494,31 @@ VectorMemIndex<T>::Build(const Config& config) {
|
||||
->Dim());
|
||||
}
|
||||
std::vector<knowhere::sparse::SparseRow<SparseValueType>> vec(
|
||||
total_rows);
|
||||
total_valid_rows);
|
||||
int64_t offset = 0;
|
||||
for (auto field_data : field_datas) {
|
||||
auto ptr = static_cast<
|
||||
const knowhere::sparse::SparseRow<SparseValueType>*>(
|
||||
field_data->Data());
|
||||
AssertInfo(ptr, "failed to cast field data to sparse rows");
|
||||
for (size_t i = 0; i < field_data->Length(); ++i) {
|
||||
for (size_t i = 0; i < field_data->get_valid_rows(); ++i) {
|
||||
// this does a deep copy of field_data's data.
|
||||
// TODO: avoid copying by enforcing field data to give up
|
||||
// ownership.
|
||||
AssertInfo(dim >= ptr[i].dim(), "bad dim");
|
||||
dim = std::max(dim, static_cast<int64_t>(ptr[i].dim()));
|
||||
vec[offset + i] = ptr[i];
|
||||
}
|
||||
offset += field_data->Length();
|
||||
offset += field_data->get_valid_rows();
|
||||
}
|
||||
auto dataset = GenDataset(total_rows, dim, vec.data());
|
||||
auto dataset = GenDataset(total_valid_rows, dim, vec.data());
|
||||
dataset->SetIsSparse(true);
|
||||
if (!scalar_info.empty()) {
|
||||
dataset->Set(knowhere::meta::SCALAR_INFO, std::move(scalar_info));
|
||||
}
|
||||
BuildWithDataset(dataset, build_config);
|
||||
if (nullable) {
|
||||
BuildValidData(valid_data.get(), total_num_rows);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -572,6 +646,10 @@ VectorMemIndex<T>::GetVector(const DatasetPtr dataset) const {
|
||||
template <typename T>
|
||||
std::unique_ptr<const knowhere::sparse::SparseRow<SparseValueType>[]>
|
||||
VectorMemIndex<T>::GetSparseVector(const DatasetPtr dataset) const {
|
||||
if (dataset->GetRows() == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto res = index_.GetVectorByIds(dataset);
|
||||
if (!res.has_value()) {
|
||||
ThrowInfo(ErrorCode::UnexpectedError,
|
||||
@ -646,6 +724,8 @@ void VectorMemIndex<T>::LoadFromFile(const Config& config) {
|
||||
LOG_INFO("load with slice meta: {}", !slice_meta_filepath.empty());
|
||||
std::chrono::duration<double> load_duration_sum;
|
||||
std::chrono::duration<double> write_disk_duration_sum;
|
||||
std::unique_ptr<storage::DataCodec> valid_data_count_codec;
|
||||
std::unique_ptr<storage::DataCodec> valid_data_codec;
|
||||
// load files in two parts:
|
||||
// 1. EMB_LIST_META: Written separately to embedding_list_meta_writer_ptr (if embedding list type)
|
||||
// 2. All other binaries: Merged and written to file_writer, forming a unified index file for knowhere
|
||||
@ -683,6 +763,10 @@ void VectorMemIndex<T>::LoadFromFile(const Config& config) {
|
||||
embedding_list_meta_writer_ptr) {
|
||||
embedding_list_meta_writer_ptr->Write(
|
||||
data->PayloadData(), data->PayloadSize());
|
||||
} else if (prefix == VALID_DATA_COUNT_KEY) {
|
||||
valid_data_count_codec = std::move(data);
|
||||
} else if (prefix == VALID_DATA_KEY) {
|
||||
valid_data_codec = std::move(data);
|
||||
} else {
|
||||
file_writer.Write(data->PayloadData(),
|
||||
data->PayloadSize());
|
||||
@ -724,6 +808,10 @@ void VectorMemIndex<T>::LoadFromFile(const Config& config) {
|
||||
embedding_list_meta_writer_ptr) {
|
||||
embedding_list_meta_writer_ptr->Write(
|
||||
index_data->PayloadData(), index_data->PayloadSize());
|
||||
} else if (prefix == VALID_DATA_COUNT_KEY) {
|
||||
valid_data_count_codec = std::move(index_data);
|
||||
} else if (prefix == VALID_DATA_KEY) {
|
||||
valid_data_codec = std::move(index_data);
|
||||
} else {
|
||||
file_writer.Write(index_data->PayloadData(),
|
||||
index_data->PayloadSize());
|
||||
@ -768,6 +856,20 @@ void VectorMemIndex<T>::LoadFromFile(const Config& config) {
|
||||
auto dim = index_.Dim();
|
||||
this->SetDim(index_.Dim());
|
||||
|
||||
// Restore valid_data for nullable vector support
|
||||
if (valid_data_count_codec && valid_data_codec) {
|
||||
size_t count;
|
||||
std::memcpy(
|
||||
&count, valid_data_count_codec->PayloadData(), sizeof(size_t));
|
||||
|
||||
std::unique_ptr<bool[]> valid_data(new bool[count]);
|
||||
auto bitmap = valid_data_codec->PayloadData();
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
valid_data[i] = (bitmap[i / 8] >> (i % 8)) & 1;
|
||||
}
|
||||
BuildValidData(valid_data.get(), count);
|
||||
}
|
||||
|
||||
this->mmap_file_raii_ =
|
||||
std::make_unique<MmapFileRAII>(local_filepath.value());
|
||||
LOG_INFO(
|
||||
|
||||
@ -22,7 +22,9 @@ class IndexCreatorBase {
|
||||
virtual ~IndexCreatorBase() = default;
|
||||
|
||||
virtual void
|
||||
Build(const milvus::DatasetPtr& dataset) = 0;
|
||||
Build(const milvus::DatasetPtr& dataset,
|
||||
const bool* valid_data = nullptr,
|
||||
const int64_t valid_data_len = 0) = 0;
|
||||
|
||||
virtual void
|
||||
Build() = 0;
|
||||
|
||||
@ -83,7 +83,11 @@ ScalarIndexCreator::ScalarIndexCreator(
|
||||
}
|
||||
|
||||
void
|
||||
ScalarIndexCreator::Build(const milvus::DatasetPtr& dataset) {
|
||||
ScalarIndexCreator::Build(const milvus::DatasetPtr& dataset,
|
||||
const bool* valid_data,
|
||||
const int64_t valid_data_len) {
|
||||
(void)valid_data;
|
||||
(void)valid_data_len;
|
||||
auto size = dataset->GetRows();
|
||||
auto data = dataset->GetTensor();
|
||||
index_->BuildWithRawDataForUT(size, data);
|
||||
|
||||
@ -27,7 +27,9 @@ class ScalarIndexCreator : public IndexCreatorBase {
|
||||
const storage::FileManagerContext& file_manager_context);
|
||||
|
||||
void
|
||||
Build(const milvus::DatasetPtr& dataset) override;
|
||||
Build(const milvus::DatasetPtr& dataset,
|
||||
const bool* valid_data = nullptr,
|
||||
const int64_t valid_data_len = 0) override;
|
||||
|
||||
void
|
||||
Build() override;
|
||||
|
||||
@ -65,8 +65,15 @@ VecIndexCreator::dim() {
|
||||
}
|
||||
|
||||
void
|
||||
VecIndexCreator::Build(const milvus::DatasetPtr& dataset) {
|
||||
VecIndexCreator::Build(const milvus::DatasetPtr& dataset,
|
||||
const bool* valid_data,
|
||||
const int64_t valid_data_len) {
|
||||
index_->BuildWithDataset(dataset, config_);
|
||||
if (valid_data && valid_data_len > 0) {
|
||||
auto vec_index = dynamic_cast<index::VectorIndex*>(index_.get());
|
||||
AssertInfo(vec_index != nullptr, "failed to cast index to VectorIndex");
|
||||
vec_index->BuildValidData(valid_data, valid_data_len);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
|
||||
@ -39,7 +39,9 @@ class VecIndexCreator : public IndexCreatorBase {
|
||||
const storage::FileManagerContext& file_manager_context);
|
||||
|
||||
void
|
||||
Build(const milvus::DatasetPtr& dataset) override;
|
||||
Build(const milvus::DatasetPtr& dataset,
|
||||
const bool* valid_data = nullptr,
|
||||
const int64_t valid_data_len = 0) override;
|
||||
|
||||
void
|
||||
Build() override;
|
||||
|
||||
@ -564,6 +564,35 @@ BuildFloatVecIndex(CIndex index,
|
||||
return status;
|
||||
}
|
||||
|
||||
// Builds a float vector index from a contiguous buffer and forwards per-row
// validity flags to the index creator.
//
// index:           opaque handle; must wrap a VecIndexCreator.
// float_value_num: number of floats in `vectors`; the row count is
//                  float_value_num / dim.
// vectors:         buffer handed to knowhere::GenDataSet.
// valid_data:      validity flags forwarded to VecIndexCreator::Build
//                  (presumably one per logical row -- confirm with caller).
// valid_data_len:  number of entries in `valid_data`.
//
// Returns a CStatus. Any std::exception raised while building is reported as
// UnexpectedError with a strdup'ed message instead of propagating.
CStatus
BuildFloatVecIndexWithValidData(CIndex index,
                                int64_t float_value_num,
                                const float* vectors,
                                const bool* valid_data,
                                int64_t valid_data_len) {
    SCOPE_CGO_CALL_METRIC();

    auto status = CStatus();
    try {
        AssertInfo(index,
                   "failed to build float vector index, passed index was null");
        auto real_index =
            reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
        auto cIndex =
            dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
        // dynamic_cast yields nullptr when the handle is not a vector index
        // creator; report that through the status instead of dereferencing.
        AssertInfo(cIndex != nullptr,
                   "failed to build float vector index, passed index was not "
                   "a vector index");
        auto dim = cIndex->dim();
        // Integer division by zero is undefined behaviour and would not be
        // intercepted by the catch block below.
        AssertInfo(dim > 0,
                   "failed to build float vector index, dim must be positive");
        auto row_nums = float_value_num / dim;
        auto ds = knowhere::GenDataSet(row_nums, dim, vectors);
        cIndex->Build(ds, valid_data, valid_data_len);
        status.error_code = Success;
        status.error_msg = "";
    } catch (std::exception& e) {
        status.error_code = UnexpectedError;
        status.error_msg = strdup(e.what());
    }
    return status;
}
|
||||
|
||||
CStatus
|
||||
BuildFloat16VecIndex(CIndex index,
|
||||
int64_t float16_value_num,
|
||||
@ -592,6 +621,36 @@ BuildFloat16VecIndex(CIndex index,
|
||||
return status;
|
||||
}
|
||||
|
||||
// Builds a float16 vector index from a raw byte buffer and forwards per-row
// validity flags to the index creator.
//
// index:             opaque handle; must wrap a VecIndexCreator.
// float16_value_num: size figure for `vectors`; the row count is computed as
//                    float16_value_num / dim / 2 (two bytes per element).
// vectors:           buffer handed to knowhere::GenDataSet.
// valid_data:        validity flags forwarded to VecIndexCreator::Build
//                    (presumably one per logical row -- confirm with caller).
// valid_data_len:    number of entries in `valid_data`.
//
// Returns a CStatus. Any std::exception raised while building is reported as
// UnexpectedError with a strdup'ed message instead of propagating.
CStatus
BuildFloat16VecIndexWithValidData(CIndex index,
                                  int64_t float16_value_num,
                                  const uint8_t* vectors,
                                  const bool* valid_data,
                                  int64_t valid_data_len) {
    SCOPE_CGO_CALL_METRIC();

    auto status = CStatus();
    try {
        AssertInfo(
            index,
            "failed to build float16 vector index, passed index was null");
        auto real_index =
            reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
        auto cIndex =
            dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
        // dynamic_cast yields nullptr when the handle is not a vector index
        // creator; report that through the status instead of dereferencing.
        AssertInfo(cIndex != nullptr,
                   "failed to build float16 vector index, passed index was "
                   "not a vector index");
        auto dim = cIndex->dim();
        // Integer division by zero is undefined behaviour and would not be
        // intercepted by the catch block below.
        AssertInfo(
            dim > 0,
            "failed to build float16 vector index, dim must be positive");
        auto row_nums = float16_value_num / dim / 2;
        auto ds = knowhere::GenDataSet(row_nums, dim, vectors);
        cIndex->Build(ds, valid_data, valid_data_len);
        status.error_code = Success;
        status.error_msg = "";
    } catch (std::exception& e) {
        status.error_code = UnexpectedError;
        status.error_msg = strdup(e.what());
    }
    return status;
}
|
||||
|
||||
CStatus
|
||||
BuildBFloat16VecIndex(CIndex index,
|
||||
int64_t bfloat16_value_num,
|
||||
@ -620,6 +679,36 @@ BuildBFloat16VecIndex(CIndex index,
|
||||
return status;
|
||||
}
|
||||
|
||||
// Builds a bfloat16 vector index from a raw byte buffer and forwards per-row
// validity flags to the index creator.
//
// index:              opaque handle; must wrap a VecIndexCreator.
// bfloat16_value_num: size figure for `vectors`; the row count is computed
//                     as bfloat16_value_num / dim / 2 (two bytes per element).
// vectors:            buffer handed to knowhere::GenDataSet.
// valid_data:         validity flags forwarded to VecIndexCreator::Build
//                     (presumably one per logical row -- confirm with caller).
// valid_data_len:     number of entries in `valid_data`.
//
// Returns a CStatus. Any std::exception raised while building is reported as
// UnexpectedError with a strdup'ed message instead of propagating.
CStatus
BuildBFloat16VecIndexWithValidData(CIndex index,
                                   int64_t bfloat16_value_num,
                                   const uint8_t* vectors,
                                   const bool* valid_data,
                                   int64_t valid_data_len) {
    SCOPE_CGO_CALL_METRIC();

    auto status = CStatus();
    try {
        AssertInfo(
            index,
            "failed to build bfloat16 vector index, passed index was null");
        auto real_index =
            reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
        auto cIndex =
            dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
        // dynamic_cast yields nullptr when the handle is not a vector index
        // creator; report that through the status instead of dereferencing.
        AssertInfo(cIndex != nullptr,
                   "failed to build bfloat16 vector index, passed index was "
                   "not a vector index");
        auto dim = cIndex->dim();
        // Integer division by zero is undefined behaviour and would not be
        // intercepted by the catch block below.
        AssertInfo(
            dim > 0,
            "failed to build bfloat16 vector index, dim must be positive");
        auto row_nums = bfloat16_value_num / dim / 2;
        auto ds = knowhere::GenDataSet(row_nums, dim, vectors);
        cIndex->Build(ds, valid_data, valid_data_len);
        status.error_code = Success;
        status.error_msg = "";
    } catch (std::exception& e) {
        status.error_code = UnexpectedError;
        status.error_msg = strdup(e.what());
    }
    return status;
}
|
||||
|
||||
CStatus
|
||||
BuildBinaryVecIndex(CIndex index, int64_t data_size, const uint8_t* vectors) {
|
||||
SCOPE_CGO_CALL_METRIC();
|
||||
@ -646,6 +735,36 @@ BuildBinaryVecIndex(CIndex index, int64_t data_size, const uint8_t* vectors) {
|
||||
return status;
|
||||
}
|
||||
|
||||
// Builds a binary vector index from a packed byte buffer and forwards per-row
// validity flags to the index creator.
//
// index:          opaque handle; must wrap a VecIndexCreator.
// data_size:      size of `vectors` in bytes; the row count is computed as
//                 (data_size * 8) / dim, i.e. dim is counted in bits.
// vectors:        buffer handed to knowhere::GenDataSet.
// valid_data:     validity flags forwarded to VecIndexCreator::Build
//                 (presumably one per logical row -- confirm with caller).
// valid_data_len: number of entries in `valid_data`.
//
// Returns a CStatus. Any std::exception raised while building is reported as
// UnexpectedError with a strdup'ed message instead of propagating.
CStatus
BuildBinaryVecIndexWithValidData(CIndex index,
                                 int64_t data_size,
                                 const uint8_t* vectors,
                                 const bool* valid_data,
                                 int64_t valid_data_len) {
    SCOPE_CGO_CALL_METRIC();

    auto status = CStatus();
    try {
        AssertInfo(
            index,
            "failed to build binary vector index, passed index was null");
        auto real_index =
            reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
        auto cIndex =
            dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
        // dynamic_cast yields nullptr when the handle is not a vector index
        // creator; report that through the status instead of dereferencing.
        AssertInfo(cIndex != nullptr,
                   "failed to build binary vector index, passed index was "
                   "not a vector index");
        auto dim = cIndex->dim();
        // Integer division by zero is undefined behaviour and would not be
        // intercepted by the catch block below.
        AssertInfo(
            dim > 0,
            "failed to build binary vector index, dim must be positive");
        auto row_nums = (data_size * 8) / dim;
        auto ds = knowhere::GenDataSet(row_nums, dim, vectors);
        cIndex->Build(ds, valid_data, valid_data_len);
        status.error_code = Success;
        status.error_msg = "";
    } catch (std::exception& e) {
        status.error_code = UnexpectedError;
        status.error_msg = strdup(e.what());
    }
    return status;
}
|
||||
|
||||
CStatus
|
||||
BuildSparseFloatVecIndex(CIndex index,
|
||||
int64_t row_num,
|
||||
@ -674,6 +793,36 @@ BuildSparseFloatVecIndex(CIndex index,
|
||||
return status;
|
||||
}
|
||||
|
||||
// Builds a sparse float vector index and forwards per-row validity flags to
// the index creator.
//
// index:          opaque handle; must wrap a VecIndexCreator.
// row_num:        number of rows passed to knowhere::GenDataSet.
// dim:            dimension passed to knowhere::GenDataSet.
// vectors:        buffer handed to knowhere::GenDataSet; the dataset is then
//                 flagged as sparse.
// valid_data:     validity flags forwarded to VecIndexCreator::Build
//                 (presumably one per logical row -- confirm with caller).
// valid_data_len: number of entries in `valid_data`.
//
// Returns a CStatus. Any std::exception raised while building is reported as
// UnexpectedError with a strdup'ed message instead of propagating.
CStatus
BuildSparseFloatVecIndexWithValidData(CIndex index,
                                      int64_t row_num,
                                      int64_t dim,
                                      const uint8_t* vectors,
                                      const bool* valid_data,
                                      int64_t valid_data_len) {
    SCOPE_CGO_CALL_METRIC();

    auto status = CStatus();
    try {
        AssertInfo(
            index,
            "failed to build sparse float vector index, passed index was null");
        auto real_index =
            reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
        auto cIndex =
            dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
        // dynamic_cast yields nullptr when the handle is not a vector index
        // creator; report that through the status instead of dereferencing.
        AssertInfo(cIndex != nullptr,
                   "failed to build sparse float vector index, passed index "
                   "was not a vector index");
        auto ds = knowhere::GenDataSet(row_num, dim, vectors);
        ds->SetIsSparse(true);
        cIndex->Build(ds, valid_data, valid_data_len);
        status.error_code = Success;
        status.error_msg = "";
    } catch (std::exception& e) {
        status.error_code = UnexpectedError;
        status.error_msg = strdup(e.what());
    }
    return status;
}
|
||||
|
||||
CStatus
|
||||
BuildInt8VecIndex(CIndex index, int64_t int8_value_num, const int8_t* vectors) {
|
||||
SCOPE_CGO_CALL_METRIC();
|
||||
@ -699,6 +848,35 @@ BuildInt8VecIndex(CIndex index, int64_t int8_value_num, const int8_t* vectors) {
|
||||
return status;
|
||||
}
|
||||
|
||||
// Builds an int8 vector index from a contiguous buffer and forwards per-row
// validity flags to the index creator.
//
// index:          opaque handle; must wrap a VecIndexCreator.
// int8_value_num: number of int8 values in `vectors`; the row count is
//                 int8_value_num / dim.
// vectors:        buffer handed to knowhere::GenDataSet.
// valid_data:     validity flags forwarded to VecIndexCreator::Build
//                 (presumably one per logical row -- confirm with caller).
// valid_data_len: number of entries in `valid_data`.
//
// Returns a CStatus. Any std::exception raised while building is reported as
// UnexpectedError with a strdup'ed message instead of propagating.
CStatus
BuildInt8VecIndexWithValidData(CIndex index,
                               int64_t int8_value_num,
                               const int8_t* vectors,
                               const bool* valid_data,
                               int64_t valid_data_len) {
    SCOPE_CGO_CALL_METRIC();

    auto status = CStatus();
    try {
        AssertInfo(index,
                   "failed to build int8 vector index, passed index was null");
        auto real_index =
            reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
        auto cIndex =
            dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
        // dynamic_cast yields nullptr when the handle is not a vector index
        // creator; report that through the status instead of dereferencing.
        AssertInfo(cIndex != nullptr,
                   "failed to build int8 vector index, passed index was not "
                   "a vector index");
        auto dim = cIndex->dim();
        // Integer division by zero is undefined behaviour and would not be
        // intercepted by the catch block below.
        AssertInfo(dim > 0,
                   "failed to build int8 vector index, dim must be positive");
        auto row_nums = int8_value_num / dim;
        auto ds = knowhere::GenDataSet(row_nums, dim, vectors);
        cIndex->Build(ds, valid_data, valid_data_len);
        status.error_code = Success;
        status.error_msg = "";
    } catch (std::exception& e) {
        status.error_code = UnexpectedError;
        status.error_msg = strdup(e.what());
    }
    return status;
}
|
||||
|
||||
// field_data:
|
||||
// 1, serialized proto::schema::BoolArray, if type is bool;
|
||||
// 2, serialized proto::schema::StringArray, if type is string;
|
||||
|
||||
@ -55,24 +55,67 @@ CreateIndexForUT(enum CDataType dtype,
|
||||
CStatus
|
||||
BuildFloatVecIndex(CIndex index, int64_t float_value_num, const float* vectors);
|
||||
|
||||
CStatus
|
||||
BuildFloatVecIndexWithValidData(CIndex index,
|
||||
int64_t float_value_num,
|
||||
const float* vectors,
|
||||
const bool* valid_data,
|
||||
int64_t valid_data_len);
|
||||
|
||||
CStatus
|
||||
BuildBinaryVecIndex(CIndex index, int64_t data_size, const uint8_t* vectors);
|
||||
|
||||
CStatus
|
||||
BuildBinaryVecIndexWithValidData(CIndex index,
|
||||
int64_t data_size,
|
||||
const uint8_t* vectors,
|
||||
const bool* valid_data,
|
||||
int64_t valid_data_len);
|
||||
|
||||
CStatus
|
||||
BuildFloat16VecIndex(CIndex index, int64_t data_size, const uint8_t* vectors);
|
||||
|
||||
CStatus
|
||||
BuildFloat16VecIndexWithValidData(CIndex index,
|
||||
int64_t data_size,
|
||||
const uint8_t* vectors,
|
||||
const bool* valid_data,
|
||||
int64_t valid_data_len);
|
||||
|
||||
CStatus
|
||||
BuildBFloat16VecIndex(CIndex index, int64_t data_size, const uint8_t* vectors);
|
||||
|
||||
CStatus
|
||||
BuildBFloat16VecIndexWithValidData(CIndex index,
|
||||
int64_t data_size,
|
||||
const uint8_t* vectors,
|
||||
const bool* valid_data,
|
||||
int64_t valid_data_len);
|
||||
|
||||
CStatus
|
||||
BuildSparseFloatVecIndex(CIndex index,
|
||||
int64_t row_num,
|
||||
int64_t dim,
|
||||
const uint8_t* vectors);
|
||||
|
||||
CStatus
|
||||
BuildSparseFloatVecIndexWithValidData(CIndex index,
|
||||
int64_t row_num,
|
||||
int64_t dim,
|
||||
const uint8_t* vectors,
|
||||
const bool* valid_data,
|
||||
int64_t valid_data_len);
|
||||
|
||||
CStatus
|
||||
BuildInt8VecIndex(CIndex index, int64_t data_size, const int8_t* vectors);
|
||||
|
||||
CStatus
|
||||
BuildInt8VecIndexWithValidData(CIndex index,
|
||||
int64_t data_size,
|
||||
const int8_t* vectors,
|
||||
const bool* valid_data,
|
||||
int64_t valid_data_len);
|
||||
|
||||
// field_data:
|
||||
// 1, serialized proto::schema::BoolArray, if type is bool;
|
||||
// 2, serialized proto::schema::StringArray, if type is string;
|
||||
|
||||
@ -368,6 +368,33 @@ class ChunkedColumnBase : public ChunkedColumnInterface {
|
||||
return meta->num_rows_until_chunk_;
|
||||
}
|
||||
|
||||
void
|
||||
BuildValidRowIds(milvus::OpContext* op_ctx) override {
|
||||
if (!nullable_) {
|
||||
return;
|
||||
}
|
||||
auto ca = SemiInlineGet(slot_->PinAllCells(op_ctx));
|
||||
int64_t logical_offset = 0;
|
||||
valid_data_.resize(num_rows_);
|
||||
valid_count_per_chunk_.resize(num_chunks_);
|
||||
for (size_t i = 0; i < num_chunks_; i++) {
|
||||
auto chunk = ca->get_cell_of(i);
|
||||
auto rows = chunk_row_nums(i);
|
||||
int64_t valid_count = 0;
|
||||
for (int64_t j = 0; j < rows; j++) {
|
||||
if (chunk->isValid(j)) {
|
||||
valid_data_[logical_offset + j] = true;
|
||||
valid_count++;
|
||||
} else {
|
||||
valid_data_[logical_offset + j] = false;
|
||||
}
|
||||
}
|
||||
valid_count_per_chunk_[i] = valid_count;
|
||||
logical_offset += rows;
|
||||
}
|
||||
BuildOffsetMapping();
|
||||
}
|
||||
|
||||
protected:
|
||||
bool nullable_{false};
|
||||
DataType data_type_{DataType::NONE};
|
||||
|
||||
@ -667,6 +667,36 @@ class ProxyChunkColumn : public ChunkedColumnInterface {
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
BuildValidRowIds(milvus::OpContext* op_ctx) override {
|
||||
if (!field_meta_.is_nullable()) {
|
||||
return;
|
||||
}
|
||||
auto total_rows = NumRows();
|
||||
auto total_chunks = num_chunks();
|
||||
valid_data_.resize(total_rows);
|
||||
valid_count_per_chunk_.resize(total_chunks);
|
||||
|
||||
int64_t logical_offset = 0;
|
||||
for (int64_t i = 0; i < total_chunks; i++) {
|
||||
auto group_chunk = group_->GetGroupChunk(op_ctx, i);
|
||||
auto chunk = group_chunk.get()->GetChunk(field_id_);
|
||||
auto rows = chunk->RowNums();
|
||||
int64_t valid_count = 0;
|
||||
for (int64_t j = 0; j < rows; j++) {
|
||||
if (chunk->isValid(j)) {
|
||||
valid_data_[logical_offset + j] = true;
|
||||
valid_count++;
|
||||
} else {
|
||||
valid_data_[logical_offset + j] = false;
|
||||
}
|
||||
}
|
||||
valid_count_per_chunk_[i] = valid_count;
|
||||
logical_offset += rows;
|
||||
}
|
||||
BuildOffsetMapping();
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<ChunkedColumnGroup> group_;
|
||||
FieldId field_id_;
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
|
||||
#include "cachinglayer/CacheSlot.h"
|
||||
#include "common/Chunk.h"
|
||||
#include "common/OffsetMapping.h"
|
||||
#include "common/bson_view.h"
|
||||
namespace milvus {
|
||||
|
||||
@ -131,6 +132,35 @@ class ChunkedColumnInterface {
|
||||
virtual const std::vector<int64_t>&
|
||||
GetNumRowsUntilChunk() const = 0;
|
||||
|
||||
// Read-only view of the cached validity flags (valid_data_). The vector is
// filled by the BuildValidRowIds overrides and is empty until one has run.
const FixedVector<bool>&
|
||||
GetValidData() const {
|
||||
return valid_data_;
|
||||
}
|
||||
|
||||
// Read-only view of valid_count_per_chunk_: entry i is the number of valid
// rows counted for chunk i by the BuildValidRowIds overrides.
const std::vector<int64_t>&
|
||||
GetValidCountPerChunk() const {
|
||||
return valid_count_per_chunk_;
|
||||
}
|
||||
|
||||
// Read-only access to this column's offset mapping. The reference points at
// a member and must not outlive the column.
const OffsetMapping&
|
||||
GetOffsetMapping() const {
|
||||
return offset_mapping_;
|
||||
}
|
||||
|
||||
// Default implementation: always throws ErrorCode::Unsupported and ignores
// `op_ctx`. Column types that track row validity override this method.
virtual void
|
||||
BuildValidRowIds(milvus::OpContext* op_ctx) {
|
||||
ThrowInfo(ErrorCode::Unsupported,
|
||||
"BuildValidRowIds not supported for this column type");
|
||||
}
|
||||
|
||||
// Build offset mapping from valid_data
|
||||
void
|
||||
BuildOffsetMapping() {
|
||||
if (!valid_data_.empty()) {
|
||||
offset_mapping_.Build(valid_data_.data(), valid_data_.size());
|
||||
}
|
||||
}
|
||||
|
||||
virtual void
|
||||
BulkValueAt(milvus::OpContext* op_ctx,
|
||||
std::function<void(const char*, size_t)> fn,
|
||||
@ -237,6 +267,10 @@ class ChunkedColumnInterface {
|
||||
}
|
||||
|
||||
protected:
|
||||
FixedVector<bool> valid_data_;
|
||||
std::vector<int64_t> valid_count_per_chunk_;
|
||||
OffsetMapping offset_mapping_;
|
||||
|
||||
std::pair<std::vector<milvus::cachinglayer::cid_t>, std::vector<int64_t>>
|
||||
ToChunkIdAndOffset(const int64_t* offsets, int64_t count) const {
|
||||
AssertInfo(offsets != nullptr, "Offsets cannot be nullptr");
|
||||
|
||||
@ -22,6 +22,7 @@
|
||||
#include "query/CachedSearchIterator.h"
|
||||
#include "query/SearchBruteForce.h"
|
||||
#include "query/SearchOnIndex.h"
|
||||
#include "query/Utils.h"
|
||||
#include "exec/operator/Utils.h"
|
||||
|
||||
namespace milvus::query {
|
||||
@ -82,8 +83,6 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
||||
SearchResult& search_result) {
|
||||
auto& schema = segment.get_schema();
|
||||
auto& record = segment.get_insert_record();
|
||||
auto active_row_count =
|
||||
std::min(int64_t(bitset.size()), segment.get_active_count(timestamp));
|
||||
|
||||
// step 1.1: get meta
|
||||
// step 1.2: get which vector field to search
|
||||
@ -155,6 +154,19 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
||||
|
||||
// step 3: brute force search where small indexing is unavailable
|
||||
auto vec_ptr = record.get_data_base(vecfield_id);
|
||||
const auto& offset_mapping = vec_ptr->get_offset_mapping();
|
||||
|
||||
TargetBitmap transformed_bitset;
|
||||
BitsetView search_bitset = bitset;
|
||||
if (offset_mapping.IsEnabled()) {
|
||||
transformed_bitset = TransformBitset(bitset, offset_mapping);
|
||||
search_bitset = BitsetView(transformed_bitset);
|
||||
}
|
||||
|
||||
auto active_count = offset_mapping.IsEnabled()
|
||||
? offset_mapping.GetValidCount()
|
||||
: std::min(int64_t(bitset.size()),
|
||||
segment.get_active_count(timestamp));
|
||||
|
||||
if (info.iterator_v2_info_.has_value()) {
|
||||
AssertInfo(data_type != DataType::VECTOR_ARRAY,
|
||||
@ -163,17 +175,20 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
||||
|
||||
CachedSearchIterator cached_iter(search_dataset,
|
||||
vec_ptr,
|
||||
active_row_count,
|
||||
active_count,
|
||||
info,
|
||||
index_info,
|
||||
bitset,
|
||||
search_bitset,
|
||||
data_type);
|
||||
cached_iter.NextBatch(info, search_result);
|
||||
if (offset_mapping.IsEnabled()) {
|
||||
TransformOffset(search_result.seg_offsets_, offset_mapping);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
auto vec_size_per_chunk = vec_ptr->get_size_per_chunk();
|
||||
auto max_chunk = upper_div(active_row_count, vec_size_per_chunk);
|
||||
auto max_chunk = upper_div(active_count, vec_size_per_chunk);
|
||||
|
||||
// embedding search embedding on embedding list
|
||||
bool embedding_search = false;
|
||||
@ -188,7 +203,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
||||
|
||||
auto row_begin = chunk_id * vec_size_per_chunk;
|
||||
auto row_end =
|
||||
std::min(active_row_count, (chunk_id + 1) * vec_size_per_chunk);
|
||||
std::min(active_count, (chunk_id + 1) * vec_size_per_chunk);
|
||||
auto size_per_chunk = row_end - row_begin;
|
||||
|
||||
query::dataset::RawDataset sub_data;
|
||||
@ -260,7 +275,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
||||
sub_data,
|
||||
info,
|
||||
index_info,
|
||||
bitset,
|
||||
search_bitset,
|
||||
vector_type);
|
||||
final_qr.merge(sub_qr);
|
||||
} else {
|
||||
@ -268,7 +283,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
||||
sub_data,
|
||||
info,
|
||||
index_info,
|
||||
bitset,
|
||||
search_bitset,
|
||||
vector_type,
|
||||
element_type,
|
||||
op_context);
|
||||
@ -286,6 +301,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
||||
max_chunk,
|
||||
chunk_rows,
|
||||
final_qr.chunk_iterators(),
|
||||
offset_mapping,
|
||||
larger_is_closer);
|
||||
} else {
|
||||
if (info.array_offsets_ != nullptr) {
|
||||
@ -300,6 +316,9 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
||||
std::move(final_qr.mutable_offsets());
|
||||
}
|
||||
search_result.distances_ = std::move(final_qr.mutable_distances());
|
||||
if (offset_mapping.IsEnabled()) {
|
||||
TransformOffset(search_result.seg_offsets_, offset_mapping);
|
||||
}
|
||||
}
|
||||
search_result.unity_topK_ = topk;
|
||||
search_result.total_nq_ = num_queries;
|
||||
|
||||
@ -10,6 +10,7 @@
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include "SearchOnIndex.h"
|
||||
#include "Utils.h"
|
||||
#include "exec/operator/Utils.h"
|
||||
#include "CachedSearchIterator.h"
|
||||
|
||||
@ -28,23 +29,39 @@ SearchOnIndex(const dataset::SearchDataset& search_dataset,
|
||||
auto dataset =
|
||||
knowhere::GenDataSet(num_queries, dim, search_dataset.query_data);
|
||||
dataset->SetIsSparse(is_sparse);
|
||||
|
||||
const auto& offset_mapping = indexing.GetOffsetMapping();
|
||||
TargetBitmap transformed_bitset;
|
||||
BitsetView search_bitset = bitset;
|
||||
if (offset_mapping.IsEnabled()) {
|
||||
transformed_bitset = TransformBitset(bitset, offset_mapping);
|
||||
search_bitset = BitsetView(transformed_bitset);
|
||||
}
|
||||
|
||||
if (milvus::exec::PrepareVectorIteratorsFromIndex(search_conf,
|
||||
num_queries,
|
||||
dataset,
|
||||
search_result,
|
||||
bitset,
|
||||
search_bitset,
|
||||
indexing)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (search_conf.iterator_v2_info_.has_value()) {
|
||||
auto iter =
|
||||
CachedSearchIterator(indexing, dataset, search_conf, bitset);
|
||||
CachedSearchIterator(indexing, dataset, search_conf, search_bitset);
|
||||
iter.NextBatch(search_conf, search_result);
|
||||
if (offset_mapping.IsEnabled()) {
|
||||
TransformOffset(search_result.seg_offsets_, offset_mapping);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
indexing.Query(dataset, search_conf, bitset, op_context, search_result);
|
||||
indexing.Query(
|
||||
dataset, search_conf, search_bitset, op_context, search_result);
|
||||
if (offset_mapping.IsEnabled()) {
|
||||
TransformOffset(search_result.seg_offsets_, offset_mapping);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace milvus::query
|
||||
|
||||
@ -16,12 +16,14 @@
|
||||
#include "bitset/detail/element_wise.h"
|
||||
#include "cachinglayer/Utils.h"
|
||||
#include "common/BitsetView.h"
|
||||
#include "common/Consts.h"
|
||||
#include "common/QueryInfo.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/Utils.h"
|
||||
#include "query/CachedSearchIterator.h"
|
||||
#include "query/SearchBruteForce.h"
|
||||
#include "query/SearchOnSealed.h"
|
||||
#include "query/Utils.h"
|
||||
#include "query/helper.h"
|
||||
#include "exec/operator/Utils.h"
|
||||
|
||||
@ -73,21 +75,40 @@ SearchOnSealedIndex(const Schema& schema,
|
||||
auto vec_index =
|
||||
dynamic_cast<index::VectorIndex*>(accessor->get_cell_of(0));
|
||||
|
||||
const auto& offset_mapping = vec_index->GetOffsetMapping();
|
||||
TargetBitmap transformed_bitset;
|
||||
BitsetView search_bitset = bitset;
|
||||
if (offset_mapping.IsEnabled()) {
|
||||
transformed_bitset = TransformBitset(bitset, offset_mapping);
|
||||
search_bitset = BitsetView(transformed_bitset);
|
||||
if (offset_mapping.GetValidCount() == 0) {
|
||||
auto total_num = num_queries * topK;
|
||||
search_result.seg_offsets_.resize(total_num, INVALID_SEG_OFFSET);
|
||||
search_result.distances_.resize(total_num, 0.0f);
|
||||
search_result.total_nq_ = num_queries;
|
||||
search_result.unity_topK_ = topK;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (search_info.iterator_v2_info_.has_value()) {
|
||||
CachedSearchIterator cached_iter(
|
||||
*vec_index, dataset, search_info, bitset);
|
||||
*vec_index, dataset, search_info, search_bitset);
|
||||
cached_iter.NextBatch(search_info, search_result);
|
||||
TransformOffset(search_result.seg_offsets_, offset_mapping);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!milvus::exec::PrepareVectorIteratorsFromIndex(search_info,
|
||||
num_queries,
|
||||
dataset,
|
||||
search_result,
|
||||
bitset,
|
||||
*vec_index)) {
|
||||
bool use_iterator =
|
||||
milvus::exec::PrepareVectorIteratorsFromIndex(search_info,
|
||||
num_queries,
|
||||
dataset,
|
||||
search_result,
|
||||
search_bitset,
|
||||
*vec_index);
|
||||
if (!use_iterator) {
|
||||
vec_index->Query(
|
||||
dataset, search_info, bitset, op_context, search_result);
|
||||
dataset, search_info, search_bitset, op_context, search_result);
|
||||
float* distances = search_result.distances_.data();
|
||||
auto total_num = num_queries * topK;
|
||||
if (round_decimal != -1) {
|
||||
@ -120,6 +141,7 @@ SearchOnSealedIndex(const Schema& schema,
|
||||
search_result.element_level_ = true;
|
||||
}
|
||||
}
|
||||
TransformOffset(search_result.seg_offsets_, offset_mapping);
|
||||
search_result.total_nq_ = num_queries;
|
||||
search_result.unity_topK_ = topK;
|
||||
}
|
||||
@ -185,12 +207,30 @@ SearchOnSealedColumn(const Schema& schema,
|
||||
}
|
||||
|
||||
auto offset = 0;
|
||||
|
||||
const auto& offset_mapping = column->GetOffsetMapping();
|
||||
TargetBitmap transformed_bitset;
|
||||
BitsetView search_bitview = bitview;
|
||||
if (offset_mapping.IsEnabled()) {
|
||||
transformed_bitset = TransformBitset(bitview, offset_mapping);
|
||||
search_bitview = BitsetView(transformed_bitset);
|
||||
if (offset_mapping.GetValidCount() == 0) {
|
||||
auto total_num = num_queries * search_info.topk_;
|
||||
result.seg_offsets_.resize(total_num, INVALID_SEG_OFFSET);
|
||||
result.distances_.resize(total_num, 0.0f);
|
||||
result.total_nq_ = num_queries;
|
||||
result.unity_topK_ = search_info.topk_;
|
||||
return;
|
||||
}
|
||||
}
|
||||
auto vector_chunks = column->GetAllChunks(op_context);
|
||||
const auto& valid_count_per_chunk = column->GetValidCountPerChunk();
|
||||
for (int i = 0; i < num_chunk; ++i) {
|
||||
auto pw = vector_chunks[i];
|
||||
auto vec_data = pw.get()->Data();
|
||||
auto chunk_size = column->chunk_row_nums(i);
|
||||
if (offset_mapping.IsEnabled() && !valid_count_per_chunk.empty()) {
|
||||
chunk_size = valid_count_per_chunk[i];
|
||||
}
|
||||
|
||||
// For element-level search, get element count from VectorArrayOffsets
|
||||
if (is_element_level_search) {
|
||||
@ -221,7 +261,7 @@ SearchOnSealedColumn(const Schema& schema,
|
||||
raw_dataset,
|
||||
search_info,
|
||||
index_info,
|
||||
bitview,
|
||||
search_bitview,
|
||||
data_type);
|
||||
final_qr.merge(sub_qr);
|
||||
} else {
|
||||
@ -229,7 +269,7 @@ SearchOnSealedColumn(const Schema& schema,
|
||||
raw_dataset,
|
||||
search_info,
|
||||
index_info,
|
||||
bitview,
|
||||
search_bitview,
|
||||
data_type,
|
||||
element_type,
|
||||
op_context);
|
||||
@ -243,6 +283,7 @@ SearchOnSealedColumn(const Schema& schema,
|
||||
num_chunk,
|
||||
column->GetNumRowsUntilChunk(),
|
||||
final_qr.chunk_iterators(),
|
||||
offset_mapping,
|
||||
larger_is_closer);
|
||||
} else {
|
||||
if (search_info.array_offsets_ != nullptr) {
|
||||
@ -256,6 +297,9 @@ SearchOnSealedColumn(const Schema& schema,
|
||||
result.seg_offsets_ = std::move(final_qr.mutable_offsets());
|
||||
}
|
||||
result.distances_ = std::move(final_qr.mutable_distances());
|
||||
if (offset_mapping.IsEnabled()) {
|
||||
TransformOffset(result.seg_offsets_, offset_mapping);
|
||||
}
|
||||
}
|
||||
result.unity_topK_ = query_dataset.topk;
|
||||
result.total_nq_ = query_dataset.num_queries;
|
||||
|
||||
@ -14,9 +14,37 @@
|
||||
#include <limits>
|
||||
#include <string>
|
||||
|
||||
#include "common/BitsetView.h"
|
||||
#include "common/OffsetMapping.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/Utils.h"
|
||||
|
||||
namespace milvus::query {
|
||||
// Projects a filter bitset expressed in logical row offsets onto the physical
// (valid-rows-only) offset space described by `mapping`.
//
// The result has one bit per physical row (mapping.GetValidCount() bits).
// A physical row inherits the bit of the logical row it maps to.
//
// Rows that cannot be mapped into the bitset (GetLogicalOffset() returned a
// negative value, or a logical offset at or beyond bitset.size()) carry no
// filter information. When a non-empty bitset is supplied, they are marked as
// filtered so that rows the caller's bitset never covered cannot surface in
// search results. An empty bitset is treated as "no filter" and leaves every
// bit clear, matching the previous behaviour.
// NOTE(review): assumes a set bit means "filtered out", as with the bitsets
// handed to the search routines -- confirm against BitsetView's contract.
inline TargetBitmap
TransformBitset(const BitsetView& bitset,
                const milvus::OffsetMapping& mapping) {
    TargetBitmap result;
    auto count = mapping.GetValidCount();
    result.resize(count);

    const auto logical_size = static_cast<int64_t>(bitset.size());
    const bool has_filter = logical_size > 0;

    for (int64_t physical_idx = 0; physical_idx < count; physical_idx++) {
        auto logical_idx = mapping.GetLogicalOffset(physical_idx);
        if (logical_idx >= 0 && logical_idx < logical_size) {
            result[physical_idx] = bitset.test(logical_idx);
        } else if (has_filter) {
            // Outside the range the caller filtered: exclude from the search.
            result[physical_idx] = true;
        }
    }
    return result;
}
|
||||
|
||||
inline void
|
||||
TransformOffset(std::vector<int64_t>& seg_offsets,
|
||||
const milvus::OffsetMapping& mapping) {
|
||||
for (auto& seg_offset : seg_offsets) {
|
||||
if (seg_offset >= 0) {
|
||||
seg_offset = mapping.GetLogicalOffset(seg_offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
inline bool
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -918,6 +918,61 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info,
|
||||
}
|
||||
}
|
||||
|
||||
// Resolves the requested logical row offsets of a vector field into the
// physical offsets of the rows that actually hold a vector.
//
// Parameters:
//   op_ctx      - operation context used to pin the vector index cell.
//   field_id    - vector field being read.
//   seg_offsets - `count` logical segment offsets.
//   count       - number of entries in `seg_offsets`.
//
// Returns a ValidResult where:
//   * valid_data    - per-request flag (length `count`), true when the row
//                     has a usable physical offset; left null when no
//                     validity information is available (every row is
//                     treated as present).
//   * valid_offsets - physical offsets of the usable rows, in request order.
//                     When no validity information is available this is a
//                     copy of `seg_offsets`, so that
//                     (valid_offsets.data(), valid_count) is always a
//                     readable range for callers.
//   * valid_count   - valid_offsets.size().
//
// The vector index is consulted when it is ready for the field; otherwise the
// raw column's offset mapping is used.
ChunkedSegmentSealedImpl::ValidResult
ChunkedSegmentSealedImpl::FilterVectorValidOffsets(milvus::OpContext* op_ctx,
                                                   FieldId field_id,
                                                   const int64_t* seg_offsets,
                                                   int64_t count) const {
    ValidResult result;
    result.valid_count = count;
    bool filtered = false;

    // Shared filtering loop for both sources of validity information.
    auto collect = [&](auto&& is_row_valid, auto&& to_physical) {
        result.valid_data = std::make_unique<bool[]>(count);
        result.valid_offsets.reserve(count);
        for (int64_t i = 0; i < count; ++i) {
            const int64_t logical_offset = seg_offsets[i];
            int64_t physical_offset = -1;
            if (is_row_valid(logical_offset)) {
                physical_offset = to_physical(logical_offset);
            }
            // Keep valid_data in lock-step with valid_offsets: a row is only
            // reported as valid when it also contributes a physical offset.
            const bool usable = physical_offset >= 0;
            result.valid_data[i] = usable;
            if (usable) {
                result.valid_offsets.push_back(physical_offset);
            }
        }
        result.valid_count =
            static_cast<int64_t>(result.valid_offsets.size());
        filtered = true;
    };

    if (vector_indexings_.is_ready(field_id)) {
        auto field_indexing = vector_indexings_.get_field_indexing(field_id);
        auto cache_index = field_indexing->indexing_;
        auto ca = SemiInlineGet(cache_index->PinCells(op_ctx, {0}));
        auto vec_index = dynamic_cast<index::VectorIndex*>(ca->get_cell_of(0));

        if (vec_index != nullptr && vec_index->HasValidData()) {
            collect(
                [&](int64_t offset) { return vec_index->IsRowValid(offset); },
                [&](int64_t offset) {
                    return vec_index->GetPhysicalOffset(offset);
                });
        }
    } else {
        auto column = get_column(field_id);
        if (column != nullptr && column->IsNullable()) {
            const auto& offset_mapping = column->GetOffsetMapping();
            collect(
                [&](int64_t offset) { return offset_mapping.IsValid(offset); },
                [&](int64_t offset) {
                    return offset_mapping.GetPhysicalOffset(offset);
                });
        }
    }

    // No validity information: logical and physical offsets coincide. Expose
    // them so valid_offsets.data() never dangles while valid_count == count.
    if (!filtered && count > 0) {
        result.valid_offsets.assign(seg_offsets, seg_offsets + count);
    }
    return result;
}
|
||||
|
||||
std::unique_ptr<DataArray>
|
||||
ChunkedSegmentSealedImpl::get_vector(milvus::OpContext* op_ctx,
|
||||
FieldId field_id,
|
||||
@ -945,16 +1000,29 @@ ChunkedSegmentSealedImpl::get_vector(milvus::OpContext* op_ctx,
|
||||
|
||||
if (has_raw_data) {
|
||||
// If index has raw data, get vector from memory.
|
||||
auto ids_ds = GenIdsDataset(count, ids);
|
||||
ValidResult filter_result;
|
||||
knowhere::DataSetPtr ids_ds;
|
||||
int64_t valid_count = count;
|
||||
const bool* valid_data = nullptr;
|
||||
if (field_meta.is_nullable()) {
|
||||
filter_result =
|
||||
FilterVectorValidOffsets(op_ctx, field_id, ids, count);
|
||||
ids_ds = GenIdsDataset(filter_result.valid_count,
|
||||
filter_result.valid_offsets.data());
|
||||
valid_count = filter_result.valid_count;
|
||||
valid_data = filter_result.valid_data.get();
|
||||
} else {
|
||||
ids_ds = GenIdsDataset(count, ids);
|
||||
}
|
||||
if (field_meta.get_data_type() == DataType::VECTOR_SPARSE_U32_F32) {
|
||||
auto res = vec_index->GetSparseVector(ids_ds);
|
||||
return segcore::CreateVectorDataArrayFrom(
|
||||
res.get(), count, field_meta);
|
||||
res.get(), valid_data, count, valid_count, field_meta);
|
||||
} else {
|
||||
// dense vector:
|
||||
auto vector = vec_index->GetVector(ids_ds);
|
||||
return segcore::CreateVectorDataArrayFrom(
|
||||
vector.data(), count, field_meta);
|
||||
vector.data(), valid_data, count, valid_count, field_meta);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1529,10 +1597,13 @@ ChunkedSegmentSealedImpl::ClearData() {
|
||||
|
||||
std::unique_ptr<DataArray>
|
||||
ChunkedSegmentSealedImpl::fill_with_empty(FieldId field_id,
|
||||
int64_t count) const {
|
||||
int64_t count,
|
||||
int64_t valid_count,
|
||||
const void* valid_data) const {
|
||||
auto& field_meta = schema_->operator[](field_id);
|
||||
if (IsVectorDataType(field_meta.get_data_type())) {
|
||||
return CreateEmptyVectorDataArray(count, field_meta);
|
||||
return CreateEmptyVectorDataArray(
|
||||
count, valid_count, valid_data, field_meta);
|
||||
}
|
||||
return CreateEmptyScalarDataArray(count, field_meta);
|
||||
}
|
||||
@ -1682,8 +1753,22 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx,
|
||||
AssertInfo(column != nullptr,
|
||||
"field {} must exist when getting raw data",
|
||||
field_id.get());
|
||||
auto ret = fill_with_empty(field_id, count);
|
||||
if (column->IsNullable()) {
|
||||
|
||||
int64_t valid_count = count;
|
||||
const bool* valid_data = nullptr;
|
||||
const int64_t* valid_offsets = seg_offsets;
|
||||
ValidResult filter_result;
|
||||
|
||||
if (field_meta.is_vector() && field_meta.is_nullable()) {
|
||||
filter_result =
|
||||
FilterVectorValidOffsets(op_ctx, field_id, seg_offsets, count);
|
||||
valid_count = filter_result.valid_count;
|
||||
valid_data = filter_result.valid_data.get();
|
||||
valid_offsets = filter_result.valid_offsets.data();
|
||||
}
|
||||
auto ret = fill_with_empty(field_id, count, valid_count, valid_data);
|
||||
|
||||
if (!field_meta.is_vector() && column->IsNullable()) {
|
||||
auto dst = ret->mutable_valid_data()->mutable_data();
|
||||
column->BulkIsValid(
|
||||
op_ctx,
|
||||
@ -1691,6 +1776,7 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx,
|
||||
seg_offsets,
|
||||
count);
|
||||
}
|
||||
|
||||
switch (field_meta.get_data_type()) {
|
||||
case DataType::VARCHAR:
|
||||
case DataType::STRING:
|
||||
@ -1827,8 +1913,8 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx,
|
||||
bulk_subscript_impl(op_ctx,
|
||||
field_meta.get_sizeof(),
|
||||
column.get(),
|
||||
seg_offsets,
|
||||
count,
|
||||
valid_offsets,
|
||||
valid_count,
|
||||
ret->mutable_vectors()
|
||||
->mutable_float_vector()
|
||||
->mutable_data()
|
||||
@ -1840,8 +1926,8 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx,
|
||||
op_ctx,
|
||||
field_meta.get_sizeof(),
|
||||
column.get(),
|
||||
seg_offsets,
|
||||
count,
|
||||
valid_offsets,
|
||||
valid_count,
|
||||
ret->mutable_vectors()->mutable_float16_vector()->data());
|
||||
break;
|
||||
}
|
||||
@ -1850,8 +1936,8 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx,
|
||||
op_ctx,
|
||||
field_meta.get_sizeof(),
|
||||
column.get(),
|
||||
seg_offsets,
|
||||
count,
|
||||
valid_offsets,
|
||||
valid_count,
|
||||
ret->mutable_vectors()->mutable_bfloat16_vector()->data());
|
||||
break;
|
||||
}
|
||||
@ -1860,8 +1946,8 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx,
|
||||
op_ctx,
|
||||
field_meta.get_sizeof(),
|
||||
column.get(),
|
||||
seg_offsets,
|
||||
count,
|
||||
valid_offsets,
|
||||
valid_count,
|
||||
ret->mutable_vectors()->mutable_binary_vector()->data());
|
||||
break;
|
||||
}
|
||||
@ -1870,8 +1956,8 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx,
|
||||
op_ctx,
|
||||
field_meta.get_sizeof(),
|
||||
column.get(),
|
||||
seg_offsets,
|
||||
count,
|
||||
valid_offsets,
|
||||
valid_count,
|
||||
ret->mutable_vectors()->mutable_int8_vector()->data());
|
||||
break;
|
||||
}
|
||||
@ -1881,7 +1967,7 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx,
|
||||
column->BulkValueAt(
|
||||
op_ctx,
|
||||
[&](const char* value, size_t i) mutable {
|
||||
auto offset = seg_offsets[i];
|
||||
auto offset = valid_offsets[i];
|
||||
auto row =
|
||||
offset != INVALID_SEG_OFFSET
|
||||
? static_cast<const knowhere::sparse::SparseRow<
|
||||
@ -1895,8 +1981,8 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx,
|
||||
max_dim = std::max(max_dim, row->dim());
|
||||
dst->add_contents(row->data(), row->data_byte_size());
|
||||
},
|
||||
seg_offsets,
|
||||
count);
|
||||
valid_offsets,
|
||||
valid_count);
|
||||
dst->set_dim(max_dim);
|
||||
ret->mutable_vectors()->set_dim(dst->dim());
|
||||
break;
|
||||
@ -1905,8 +1991,8 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx,
|
||||
bulk_subscript_vector_array_impl(
|
||||
op_ctx,
|
||||
column.get(),
|
||||
seg_offsets,
|
||||
count,
|
||||
valid_offsets,
|
||||
valid_count,
|
||||
ret->mutable_vectors()->mutable_vector_array()->mutable_data());
|
||||
break;
|
||||
}
|
||||
@ -2267,7 +2353,12 @@ ChunkedSegmentSealedImpl::generate_interim_index(const FieldId field_id,
|
||||
return false;
|
||||
}
|
||||
try {
|
||||
int64_t row_count = num_rows;
|
||||
std::shared_ptr<ChunkedColumnInterface> vec_data = get_column(field_id);
|
||||
AssertInfo(
|
||||
vec_data != nullptr, "vector field {} not loaded", field_id.get());
|
||||
int64_t row_count = field_meta.is_nullable()
|
||||
? vec_data->GetOffsetMapping().GetValidCount()
|
||||
: num_rows;
|
||||
|
||||
// generate index params
|
||||
auto field_binlog_config = std::unique_ptr<VecIndexConfig>(
|
||||
@ -2279,9 +2370,6 @@ ChunkedSegmentSealedImpl::generate_interim_index(const FieldId field_id,
|
||||
if (row_count < field_binlog_config->GetBuildThreshold()) {
|
||||
return false;
|
||||
}
|
||||
std::shared_ptr<ChunkedColumnInterface> vec_data = get_column(field_id);
|
||||
AssertInfo(
|
||||
vec_data != nullptr, "vector field {} not loaded", field_id.get());
|
||||
auto dim = is_sparse ? std::numeric_limits<uint32_t>::max()
|
||||
: field_meta.get_dim();
|
||||
auto interim_index_type = field_binlog_config->GetIndexType();
|
||||
@ -2397,6 +2485,10 @@ ChunkedSegmentSealedImpl::load_field_data_common(
|
||||
return;
|
||||
}
|
||||
|
||||
if (column->IsNullable() && IsVectorDataType(data_type)) {
|
||||
column->BuildValidRowIds(nullptr);
|
||||
}
|
||||
|
||||
if (!enable_mmap) {
|
||||
if (!is_proxy_column ||
|
||||
is_proxy_column &&
|
||||
@ -2517,12 +2609,7 @@ ChunkedSegmentSealedImpl::Reopen(SchemaPtr sch) {
|
||||
|
||||
auto absent_fields = sch->AbsentFields(*schema_);
|
||||
for (const auto& field_meta : *absent_fields) {
|
||||
// vector field is not supported to be "added field", thus if a vector
|
||||
// field is absent, it means for some reason we want to skip loading this
|
||||
// field.
|
||||
if (!IsVectorDataType(field_meta.get_data_type())) {
|
||||
fill_empty_field(field_meta);
|
||||
}
|
||||
fill_empty_field(field_meta);
|
||||
}
|
||||
|
||||
schema_ = sch;
|
||||
@ -2597,10 +2684,6 @@ ChunkedSegmentSealedImpl::FinishLoad() {
|
||||
// no filling fields that index already loaded and has raw data
|
||||
continue;
|
||||
}
|
||||
if (IsVectorDataType(field_meta.get_data_type())) {
|
||||
// no filling vector fields
|
||||
continue;
|
||||
}
|
||||
fill_empty_field(field_meta);
|
||||
}
|
||||
}
|
||||
@ -2608,10 +2691,11 @@ ChunkedSegmentSealedImpl::FinishLoad() {
|
||||
void
|
||||
ChunkedSegmentSealedImpl::fill_empty_field(const FieldMeta& field_meta) {
|
||||
auto field_id = field_meta.get_id();
|
||||
auto data_type = field_meta.get_data_type();
|
||||
LOG_INFO(
|
||||
"start fill empty field {} (data type {}) for sealed segment "
|
||||
"{}",
|
||||
field_meta.get_data_type(),
|
||||
data_type,
|
||||
field_id.get(),
|
||||
id_);
|
||||
int64_t size = num_rows_.value();
|
||||
@ -2620,40 +2704,11 @@ ChunkedSegmentSealedImpl::fill_empty_field(const FieldMeta& field_meta) {
|
||||
std::unique_ptr<Translator<milvus::Chunk>> translator =
|
||||
std::make_unique<storagev1translator::DefaultValueChunkTranslator>(
|
||||
get_segment_id(), field_meta, field_data_info, false);
|
||||
std::shared_ptr<milvus::ChunkedColumnBase> column{};
|
||||
switch (field_meta.get_data_type()) {
|
||||
case milvus::DataType::STRING:
|
||||
case milvus::DataType::VARCHAR:
|
||||
case milvus::DataType::TEXT: {
|
||||
column = std::make_shared<ChunkedVariableColumn<std::string>>(
|
||||
std::move(translator), field_meta);
|
||||
break;
|
||||
}
|
||||
case milvus::DataType::JSON: {
|
||||
column = std::make_shared<ChunkedVariableColumn<milvus::Json>>(
|
||||
std::move(translator), field_meta);
|
||||
break;
|
||||
}
|
||||
case milvus::DataType::GEOMETRY: {
|
||||
column = std::make_shared<ChunkedVariableColumn<std::string>>(
|
||||
std::move(translator), field_meta);
|
||||
break;
|
||||
}
|
||||
case milvus::DataType::ARRAY: {
|
||||
column = std::make_shared<ChunkedArrayColumn>(std::move(translator),
|
||||
field_meta);
|
||||
break;
|
||||
}
|
||||
case milvus::DataType::VECTOR_ARRAY: {
|
||||
column = std::make_shared<ChunkedVectorArrayColumn>(
|
||||
std::move(translator), field_meta);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
column = std::make_shared<ChunkedColumn>(std::move(translator),
|
||||
field_meta);
|
||||
break;
|
||||
}
|
||||
auto column =
|
||||
MakeChunkedColumnBase(data_type, std::move(translator), field_meta);
|
||||
|
||||
if (column->IsNullable() && IsVectorDataType(data_type)) {
|
||||
column->BuildValidRowIds(nullptr);
|
||||
}
|
||||
|
||||
fields_.wlock()->emplace(field_id, column);
|
||||
|
||||
@ -880,7 +880,10 @@ class ChunkedSegmentSealedImpl : public SegmentSealed {
|
||||
google::protobuf::RepeatedPtrField<T>* dst);
|
||||
|
||||
std::unique_ptr<DataArray>
|
||||
fill_with_empty(FieldId field_id, int64_t count) const;
|
||||
fill_with_empty(FieldId field_id,
|
||||
int64_t count,
|
||||
int64_t valid_count = 0,
|
||||
const void* valid_data = nullptr) const;
|
||||
|
||||
std::unique_ptr<DataArray>
|
||||
get_raw_data(milvus::OpContext* op_ctx,
|
||||
@ -889,6 +892,18 @@ class ChunkedSegmentSealedImpl : public SegmentSealed {
|
||||
const int64_t* seg_offsets,
|
||||
int64_t count) const;
|
||||
|
||||
// Outcome of FilterVectorValidOffsets for a batch of requested rows.
struct ValidResult {
    // Number of entries in valid_offsets.
    int64_t valid_count = 0;
    // Per-requested-row validity flags; may be null when no validity
    // information was produced.
    std::unique_ptr<bool[]> valid_data;
    // Physical offsets of the rows that hold data, in request order.
    std::vector<int64_t> valid_offsets;
};
|
||||
|
||||
ValidResult
|
||||
FilterVectorValidOffsets(milvus::OpContext* op_ctx,
|
||||
FieldId field_id,
|
||||
const int64_t* seg_offsets,
|
||||
int64_t count) const;
|
||||
|
||||
void
|
||||
update_row_count(int64_t row_count) {
|
||||
num_rows_ = row_count;
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <shared_mutex>
|
||||
#include <string>
|
||||
@ -29,6 +30,7 @@
|
||||
#include "common/FieldMeta.h"
|
||||
#include "common/FieldData.h"
|
||||
#include "common/Json.h"
|
||||
#include "common/OffsetMapping.h"
|
||||
#include "common/Span.h"
|
||||
#include "common/Types.h"
|
||||
#include "common/Utils.h"
|
||||
@ -79,19 +81,30 @@ class ThreadSafeValidData {
|
||||
}
|
||||
|
||||
bool
|
||||
is_valid(size_t offset) {
|
||||
is_valid(size_t offset) const {
|
||||
std::shared_lock<std::shared_mutex> lck(mutex_);
|
||||
Assert(offset < length_);
|
||||
AssertInfo(offset < length_,
|
||||
"offset out of range, offset={}, length_={}",
|
||||
offset,
|
||||
length_);
|
||||
return data_[offset];
|
||||
}
|
||||
|
||||
bool*
|
||||
get_chunk_data(size_t offset) {
|
||||
std::shared_lock<std::shared_mutex> lck(mutex_);
|
||||
Assert(offset < length_);
|
||||
AssertInfo(offset < length_,
|
||||
"offset out of range, offset={}, length_={}",
|
||||
offset,
|
||||
length_);
|
||||
return &data_[offset];
|
||||
}
|
||||
|
||||
// Exposes the underlying validity flags by reference.
// NOTE(review): unlike the other accessors of this class, no lock is taken
// here and the reference outlives the call -- confirm callers do not read it
// concurrently with writers.
const FixedVector<bool>&
get_data() const {
    return data_;
}
|
||||
|
||||
private:
|
||||
mutable std::shared_mutex mutex_{};
|
||||
FixedVector<bool> data_;
|
||||
@ -155,6 +168,40 @@ class VectorBase {
|
||||
virtual void
|
||||
clear() = 0;
|
||||
|
||||
virtual bool
|
||||
is_mapping_storage() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get physical offset from logical offset. Returns -1 if not found.
|
||||
virtual int64_t
|
||||
get_physical_offset(int64_t logical_offset) const {
|
||||
return logical_offset; // default: no mapping
|
||||
}
|
||||
|
||||
// Get logical offset from physical offset. Returns -1 if not found.
|
||||
virtual int64_t
|
||||
get_logical_offset(int64_t physical_offset) const {
|
||||
return physical_offset; // default: no mapping
|
||||
}
|
||||
|
||||
virtual int64_t
|
||||
get_valid_count() const {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Validity flags backing this container. The base implementation owns none
// and hands back a shared, permanently empty vector.
virtual const FixedVector<bool>&
get_valid_data() const {
    static const FixedVector<bool> empty;
    return empty;
}
|
||||
|
||||
// Offset mapping used by this container. The base implementation has none
// and hands back a shared default-constructed mapping.
virtual const OffsetMapping&
get_offset_mapping() const {
    static const OffsetMapping empty;
    return empty;
}
|
||||
|
||||
protected:
|
||||
const int64_t size_per_chunk_;
|
||||
};
|
||||
@ -191,10 +238,12 @@ class ConcurrentVectorImpl : public VectorBase {
|
||||
ssize_t elements_per_row,
|
||||
int64_t size_per_chunk,
|
||||
storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr,
|
||||
ThreadSafeValidDataPtr valid_data_ptr = nullptr)
|
||||
ThreadSafeValidDataPtr valid_data_ptr = nullptr,
|
||||
bool use_mapping_storage = false)
|
||||
: VectorBase(size_per_chunk),
|
||||
elements_per_row_(is_type_entire_row ? 1 : elements_per_row),
|
||||
valid_data_ptr_(valid_data_ptr) {
|
||||
valid_data_ptr_(valid_data_ptr),
|
||||
use_mapping_storage_(use_mapping_storage) {
|
||||
chunks_ptr_ = SelectChunkVectorPtr<Type>(mmap_descriptor);
|
||||
}
|
||||
|
||||
@ -221,19 +270,7 @@ class ConcurrentVectorImpl : public VectorBase {
|
||||
void
|
||||
fill_chunk_data(const std::vector<FieldDataPtr>& datas) override {
|
||||
AssertInfo(chunks_ptr_->size() == 0, "non empty concurrent vector");
|
||||
|
||||
int64_t element_count = 0;
|
||||
for (auto& field_data : datas) {
|
||||
element_count += field_data->get_num_rows();
|
||||
}
|
||||
chunks_ptr_->emplace_to_at_least(1, elements_per_row_ * element_count);
|
||||
int64_t offset = 0;
|
||||
for (auto& field_data : datas) {
|
||||
auto num_rows = field_data->get_num_rows();
|
||||
set_data(
|
||||
offset, static_cast<const Type*>(field_data->Data()), num_rows);
|
||||
offset += num_rows;
|
||||
}
|
||||
set_data_raw(0, datas);
|
||||
}
|
||||
|
||||
void
|
||||
@ -250,16 +287,39 @@ class ConcurrentVectorImpl : public VectorBase {
|
||||
set_data_raw(ssize_t element_offset,
|
||||
const void* source,
|
||||
ssize_t element_count) override {
|
||||
if (element_count == 0) {
|
||||
return;
|
||||
ssize_t valid_count = 0;
|
||||
ssize_t storage_offset = 0;
|
||||
if (use_mapping_storage_) {
|
||||
if constexpr (!std::is_same_v<Type, bool>) {
|
||||
storage_offset = offset_mapping_.GetNextPhysicalOffset();
|
||||
// Build valid_data array for offset mapping
|
||||
std::unique_ptr<bool[]> valid_data(new bool[element_count]);
|
||||
for (ssize_t i = 0; i < element_count; ++i) {
|
||||
bool is_valid =
|
||||
valid_data_ptr_->is_valid(element_offset + i);
|
||||
valid_data[i] = is_valid;
|
||||
if (is_valid) {
|
||||
valid_count++;
|
||||
}
|
||||
}
|
||||
offset_mapping_.BuildIncremental(valid_data.get(),
|
||||
element_count,
|
||||
element_offset,
|
||||
storage_offset);
|
||||
}
|
||||
} else {
|
||||
valid_count = element_count;
|
||||
storage_offset = element_offset;
|
||||
}
|
||||
if (valid_count > 0) {
|
||||
auto size = size_per_chunk_ == MAX_ROW_COUNT ? valid_count
|
||||
: size_per_chunk_;
|
||||
chunks_ptr_->emplace_to_at_least(
|
||||
upper_div(storage_offset + valid_count, size),
|
||||
elements_per_row_ * size);
|
||||
set_data(
|
||||
storage_offset, static_cast<const Type*>(source), valid_count);
|
||||
}
|
||||
auto size =
|
||||
size_per_chunk_ == MAX_ROW_COUNT ? element_count : size_per_chunk_;
|
||||
chunks_ptr_->emplace_to_at_least(
|
||||
upper_div(element_offset + element_count, size),
|
||||
elements_per_row_ * size);
|
||||
set_data(
|
||||
element_offset, static_cast<const Type*>(source), element_count);
|
||||
}
|
||||
|
||||
const void*
|
||||
@ -297,8 +357,24 @@ class ConcurrentVectorImpl : public VectorBase {
|
||||
// just for fun, don't use it directly
|
||||
const Type*
|
||||
get_element(ssize_t element_index) const {
|
||||
auto chunk_id = element_index / size_per_chunk_;
|
||||
auto chunk_offset = element_index % size_per_chunk_;
|
||||
auto physical_index = offset_mapping_.GetPhysicalOffset(element_index);
|
||||
if (physical_index == -1) {
|
||||
return nullptr;
|
||||
}
|
||||
auto chunk_id = physical_index / size_per_chunk_;
|
||||
auto chunk_offset = physical_index % size_per_chunk_;
|
||||
auto data =
|
||||
static_cast<const Type*>(chunks_ptr_->get_chunk_data(chunk_id));
|
||||
return data + chunk_offset * elements_per_row_;
|
||||
}
|
||||
|
||||
const Type*
|
||||
get_physical_element(ssize_t physical_index) const {
|
||||
if (physical_index == -1) {
|
||||
return nullptr;
|
||||
}
|
||||
auto chunk_id = physical_index / size_per_chunk_;
|
||||
auto chunk_offset = physical_index % size_per_chunk_;
|
||||
auto data =
|
||||
static_cast<const Type*>(chunks_ptr_->get_chunk_data(chunk_id));
|
||||
return data + chunk_offset * elements_per_row_;
|
||||
@ -344,6 +420,40 @@ class ConcurrentVectorImpl : public VectorBase {
|
||||
return chunks_ptr_->is_mmap();
|
||||
}
|
||||
|
||||
bool
|
||||
is_mapping_storage() const override {
|
||||
return use_mapping_storage_;
|
||||
}
|
||||
|
||||
int64_t
|
||||
get_physical_offset(int64_t logical_offset) const override {
|
||||
return offset_mapping_.GetPhysicalOffset(logical_offset);
|
||||
}
|
||||
|
||||
int64_t
|
||||
get_logical_offset(int64_t physical_offset) const override {
|
||||
return offset_mapping_.GetLogicalOffset(physical_offset);
|
||||
}
|
||||
|
||||
int64_t
|
||||
get_valid_count() const override {
|
||||
return offset_mapping_.GetValidCount();
|
||||
}
|
||||
|
||||
// Read-only access to the offset mapping owned by this vector.
const milvus::OffsetMapping&
get_offset_mapping() const override {
    return offset_mapping_;
}
|
||||
|
||||
// Validity flags of this vector: the data held by valid_data_ptr_ when one
// was supplied, otherwise a shared, permanently empty vector.
const FixedVector<bool>&
get_valid_data() const override {
    static const FixedVector<bool> empty;
    return valid_data_ptr_ == nullptr ? empty : valid_data_ptr_->get_data();
}
|
||||
|
||||
private:
|
||||
void
|
||||
set_data(ssize_t element_offset,
|
||||
@ -395,9 +505,10 @@ class ConcurrentVectorImpl : public VectorBase {
|
||||
fmt::format("chunk_id out of chunk num, chunk_id={}, chunk_num={}",
|
||||
chunk_id,
|
||||
chunk_num));
|
||||
size_t chunk_id_offset = chunk_id * size_per_chunk_ * elements_per_row_;
|
||||
std::optional<CheckDataValid> check_data_valid = std::nullopt;
|
||||
if (valid_data_ptr_ != nullptr) {
|
||||
if (valid_data_ptr_ != nullptr && !use_mapping_storage_) {
|
||||
size_t chunk_id_offset =
|
||||
chunk_id * size_per_chunk_ * elements_per_row_;
|
||||
check_data_valid = [valid_data_ptr = valid_data_ptr_,
|
||||
beg_id = chunk_id_offset](size_t offset) {
|
||||
return valid_data_ptr->is_valid(beg_id + offset);
|
||||
@ -414,6 +525,9 @@ class ConcurrentVectorImpl : public VectorBase {
|
||||
const ssize_t elements_per_row_;
|
||||
ChunkVectorPtr<Type> chunks_ptr_ = nullptr;
|
||||
ThreadSafeValidDataPtr valid_data_ptr_ = nullptr;
|
||||
|
||||
const bool use_mapping_storage_;
|
||||
milvus::OffsetMapping offset_mapping_;
|
||||
};
|
||||
|
||||
template <typename Type>
|
||||
@ -496,9 +610,14 @@ class ConcurrentVector<VectorArray>
|
||||
int64_t dim /* not use it*/,
|
||||
int64_t size_per_chunk,
|
||||
storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr,
|
||||
ThreadSafeValidDataPtr valid_data_ptr = nullptr)
|
||||
ThreadSafeValidDataPtr valid_data_ptr = nullptr,
|
||||
bool use_mapping_storage = false)
|
||||
: ConcurrentVectorImpl<VectorArray, true>::ConcurrentVectorImpl(
|
||||
1, size_per_chunk, std::move(mmap_descriptor), valid_data_ptr) {
|
||||
1,
|
||||
size_per_chunk,
|
||||
std::move(mmap_descriptor),
|
||||
valid_data_ptr,
|
||||
use_mapping_storage) {
|
||||
}
|
||||
};
|
||||
|
||||
@ -510,13 +629,15 @@ class ConcurrentVector<SparseFloatVector>
|
||||
explicit ConcurrentVector(
|
||||
int64_t size_per_chunk,
|
||||
storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr,
|
||||
ThreadSafeValidDataPtr valid_data_ptr = nullptr)
|
||||
ThreadSafeValidDataPtr valid_data_ptr = nullptr,
|
||||
bool use_mapping_storage = false)
|
||||
: ConcurrentVectorImpl<knowhere::sparse::SparseRow<SparseValueType>,
|
||||
true>::ConcurrentVectorImpl(1,
|
||||
size_per_chunk,
|
||||
std::move(
|
||||
mmap_descriptor),
|
||||
valid_data_ptr),
|
||||
valid_data_ptr,
|
||||
use_mapping_storage),
|
||||
dim_(0) {
|
||||
}
|
||||
|
||||
@ -527,7 +648,16 @@ class ConcurrentVector<SparseFloatVector>
|
||||
auto* src =
|
||||
static_cast<const knowhere::sparse::SparseRow<SparseValueType>*>(
|
||||
source);
|
||||
for (int i = 0; i < element_count; ++i) {
|
||||
ssize_t source_count = element_count;
|
||||
if (this->use_mapping_storage_) {
|
||||
source_count = 0;
|
||||
for (ssize_t i = 0; i < element_count; ++i) {
|
||||
if (this->valid_data_ptr_->is_valid(element_offset + i)) {
|
||||
source_count++;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (ssize_t i = 0; i < source_count; ++i) {
|
||||
dim_ = std::max(dim_, src[i].dim());
|
||||
}
|
||||
ConcurrentVectorImpl<knowhere::sparse::SparseRow<SparseValueType>,
|
||||
@ -552,9 +682,14 @@ class ConcurrentVector<FloatVector>
|
||||
ConcurrentVector(int64_t dim,
|
||||
int64_t size_per_chunk,
|
||||
storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr,
|
||||
ThreadSafeValidDataPtr valid_data_ptr = nullptr)
|
||||
ThreadSafeValidDataPtr valid_data_ptr = nullptr,
|
||||
bool use_mapping_storage = false)
|
||||
: ConcurrentVectorImpl<float, false>::ConcurrentVectorImpl(
|
||||
dim, size_per_chunk, std::move(mmap_descriptor), valid_data_ptr) {
|
||||
dim,
|
||||
size_per_chunk,
|
||||
std::move(mmap_descriptor),
|
||||
valid_data_ptr,
|
||||
use_mapping_storage) {
|
||||
}
|
||||
};
|
||||
|
||||
@ -566,11 +701,13 @@ class ConcurrentVector<BinaryVector>
|
||||
int64_t dim,
|
||||
int64_t size_per_chunk,
|
||||
storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr,
|
||||
ThreadSafeValidDataPtr valid_data_ptr = nullptr)
|
||||
ThreadSafeValidDataPtr valid_data_ptr = nullptr,
|
||||
bool use_mapping_storage = false)
|
||||
: ConcurrentVectorImpl(dim / 8,
|
||||
size_per_chunk,
|
||||
std::move(mmap_descriptor),
|
||||
valid_data_ptr) {
|
||||
valid_data_ptr,
|
||||
use_mapping_storage) {
|
||||
AssertInfo(dim % 8 == 0,
|
||||
fmt::format("dim is not a multiple of 8, dim={}", dim));
|
||||
}
|
||||
@ -583,9 +720,14 @@ class ConcurrentVector<Float16Vector>
|
||||
ConcurrentVector(int64_t dim,
|
||||
int64_t size_per_chunk,
|
||||
storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr,
|
||||
ThreadSafeValidDataPtr valid_data_ptr = nullptr)
|
||||
ThreadSafeValidDataPtr valid_data_ptr = nullptr,
|
||||
bool use_mapping_storage = false)
|
||||
: ConcurrentVectorImpl<float16, false>::ConcurrentVectorImpl(
|
||||
dim, size_per_chunk, std::move(mmap_descriptor), valid_data_ptr) {
|
||||
dim,
|
||||
size_per_chunk,
|
||||
std::move(mmap_descriptor),
|
||||
valid_data_ptr,
|
||||
use_mapping_storage) {
|
||||
}
|
||||
};
|
||||
|
||||
@ -596,9 +738,14 @@ class ConcurrentVector<BFloat16Vector>
|
||||
ConcurrentVector(int64_t dim,
|
||||
int64_t size_per_chunk,
|
||||
storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr,
|
||||
ThreadSafeValidDataPtr valid_data_ptr = nullptr)
|
||||
ThreadSafeValidDataPtr valid_data_ptr = nullptr,
|
||||
bool use_mapping_storage = false)
|
||||
: ConcurrentVectorImpl<bfloat16, false>::ConcurrentVectorImpl(
|
||||
dim, size_per_chunk, std::move(mmap_descriptor), valid_data_ptr) {
|
||||
dim,
|
||||
size_per_chunk,
|
||||
std::move(mmap_descriptor),
|
||||
valid_data_ptr,
|
||||
use_mapping_storage) {
|
||||
}
|
||||
};
|
||||
|
||||
@ -608,9 +755,14 @@ class ConcurrentVector<Int8Vector> : public ConcurrentVectorImpl<int8, false> {
|
||||
ConcurrentVector(int64_t dim,
|
||||
int64_t size_per_chunk,
|
||||
storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr,
|
||||
ThreadSafeValidDataPtr valid_data_ptr = nullptr)
|
||||
ThreadSafeValidDataPtr valid_data_ptr = nullptr,
|
||||
bool use_mapping_storage = false)
|
||||
: ConcurrentVectorImpl<int8, false>::ConcurrentVectorImpl(
|
||||
dim, size_per_chunk, std::move(mmap_descriptor)) {
|
||||
dim,
|
||||
size_per_chunk,
|
||||
std::move(mmap_descriptor),
|
||||
valid_data_ptr,
|
||||
use_mapping_storage) {
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -9,8 +9,10 @@
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "common/EasyAssert.h"
|
||||
#include "common/Types.h"
|
||||
@ -19,6 +21,7 @@
|
||||
#include "index/StringIndexMarisa.h"
|
||||
|
||||
#include "common/SystemProperty.h"
|
||||
#include "segcore/ConcurrentVector.h"
|
||||
#include "segcore/FieldIndexing.h"
|
||||
#include "index/VectorMemIndex.h"
|
||||
#include "IndexConfigGenerator.h"
|
||||
@ -29,6 +32,104 @@
|
||||
namespace milvus::segcore {
|
||||
using std::unique_ptr;
|
||||
|
||||
void
|
||||
IndexingRecord::AppendingIndex(int64_t reserved_offset,
|
||||
int64_t size,
|
||||
FieldId fieldId,
|
||||
const DataArray* stream_data,
|
||||
const InsertRecord<false>& record) {
|
||||
if (!is_in(fieldId)) {
|
||||
return;
|
||||
}
|
||||
auto& indexing = field_indexings_.at(fieldId);
|
||||
auto type = indexing->get_data_type();
|
||||
auto field_raw_data = record.get_data_base(fieldId);
|
||||
auto field_meta = schema_.get_fields().at(fieldId);
|
||||
int64_t valid_count = reserved_offset + size;
|
||||
if (field_meta.is_nullable() && field_raw_data->is_mapping_storage()) {
|
||||
valid_count = field_raw_data->get_valid_count();
|
||||
}
|
||||
if (type == DataType::VECTOR_FLOAT &&
|
||||
valid_count >= indexing->get_build_threshold()) {
|
||||
indexing->AppendSegmentIndexDense(
|
||||
reserved_offset,
|
||||
size,
|
||||
field_raw_data,
|
||||
stream_data->vectors().float_vector().data().data());
|
||||
} else if (type == DataType::VECTOR_FLOAT16 &&
|
||||
valid_count >= indexing->get_build_threshold()) {
|
||||
indexing->AppendSegmentIndexDense(
|
||||
reserved_offset,
|
||||
size,
|
||||
field_raw_data,
|
||||
stream_data->vectors().float16_vector().data());
|
||||
} else if (type == DataType::VECTOR_BFLOAT16 &&
|
||||
valid_count >= indexing->get_build_threshold()) {
|
||||
indexing->AppendSegmentIndexDense(
|
||||
reserved_offset,
|
||||
size,
|
||||
field_raw_data,
|
||||
stream_data->vectors().bfloat16_vector().data());
|
||||
} else if (type == DataType::VECTOR_SPARSE_U32_F32 &&
|
||||
valid_count >= indexing->get_build_threshold()) {
|
||||
auto data = SparseBytesToRows(
|
||||
stream_data->vectors().sparse_float_vector().contents());
|
||||
indexing->AppendSegmentIndexSparse(
|
||||
reserved_offset,
|
||||
size,
|
||||
stream_data->vectors().sparse_float_vector().dim(),
|
||||
field_raw_data,
|
||||
data.get());
|
||||
} else if (type == DataType::GEOMETRY) {
|
||||
// For geometry fields, append data incrementally to RTree index
|
||||
indexing->AppendSegmentIndex(
|
||||
reserved_offset, size, field_raw_data, stream_data);
|
||||
}
|
||||
}
|
||||
|
||||
// concurrent, reentrant
|
||||
void
|
||||
IndexingRecord::AppendingIndex(int64_t reserved_offset,
|
||||
int64_t size,
|
||||
FieldId fieldId,
|
||||
const FieldDataPtr data,
|
||||
const InsertRecord<false>& record) {
|
||||
if (!is_in(fieldId)) {
|
||||
return;
|
||||
}
|
||||
auto& indexing = field_indexings_.at(fieldId);
|
||||
auto type = indexing->get_data_type();
|
||||
const void* p = data->Data();
|
||||
auto vec_base = record.get_data_base(fieldId);
|
||||
auto field_meta = schema_.get_fields().at(fieldId);
|
||||
int64_t valid_count = reserved_offset + size;
|
||||
if (field_meta.is_nullable() && vec_base->is_mapping_storage()) {
|
||||
valid_count = vec_base->get_valid_count();
|
||||
}
|
||||
|
||||
if ((type == DataType::VECTOR_FLOAT || type == DataType::VECTOR_FLOAT16 ||
|
||||
type == DataType::VECTOR_BFLOAT16) &&
|
||||
valid_count >= indexing->get_build_threshold()) {
|
||||
auto vec_base = record.get_data_base(fieldId);
|
||||
indexing->AppendSegmentIndexDense(
|
||||
reserved_offset, size, vec_base, data->Data());
|
||||
} else if (type == DataType::VECTOR_SPARSE_U32_F32 &&
|
||||
valid_count >= indexing->get_build_threshold()) {
|
||||
auto vec_base = record.get_data_base(fieldId);
|
||||
indexing->AppendSegmentIndexSparse(
|
||||
reserved_offset,
|
||||
size,
|
||||
std::dynamic_pointer_cast<const FieldData<SparseFloatVector>>(data)
|
||||
->Dim(),
|
||||
vec_base,
|
||||
p);
|
||||
} else if (type == DataType::GEOMETRY) {
|
||||
// For geometry fields, append data incrementally to RTree index
|
||||
auto vec_base = record.get_data_base(fieldId);
|
||||
indexing->AppendSegmentIndex(reserved_offset, size, vec_base, data);
|
||||
}
|
||||
}
|
||||
|
||||
VectorFieldIndexing::VectorFieldIndexing(const FieldMeta& field_meta,
|
||||
const FieldIndexMeta& field_index_meta,
|
||||
int64_t segment_max_row_count,
|
||||
@ -140,54 +241,133 @@ VectorFieldIndexing::AppendSegmentIndexSparse(int64_t reserved_offset,
|
||||
int64_t new_data_dim,
|
||||
const VectorBase* field_raw_data,
|
||||
const void* data_source) {
|
||||
using value_type = knowhere::sparse::SparseRow<SparseValueType>;
|
||||
AssertInfo(get_data_type() == DataType::VECTOR_SPARSE_U32_F32,
|
||||
"Data type of vector field is not VECTOR_SPARSE_U32_F32");
|
||||
|
||||
auto conf = get_build_params(get_data_type());
|
||||
auto source = dynamic_cast<const ConcurrentVector<SparseFloatVector>*>(
|
||||
field_raw_data);
|
||||
AssertInfo(source,
|
||||
auto field_source =
|
||||
dynamic_cast<const ConcurrentVector<SparseFloatVector>*>(
|
||||
field_raw_data);
|
||||
AssertInfo(field_source,
|
||||
"field_raw_data can't cast to "
|
||||
"ConcurrentVector<SparseFloatVector> type");
|
||||
AssertInfo(size > 0, "append 0 sparse rows to index is not allowed");
|
||||
if (!built_) {
|
||||
AssertInfo(!sync_with_index_, "index marked synced before built");
|
||||
idx_t total_rows = reserved_offset + size;
|
||||
idx_t chunk_id = 0;
|
||||
auto dim = source->Dim();
|
||||
auto source = static_cast<const value_type*>(data_source);
|
||||
|
||||
while (total_rows > 0) {
|
||||
auto mat = static_cast<
|
||||
const knowhere::sparse::SparseRow<SparseValueType>*>(
|
||||
source->get_chunk_data(chunk_id));
|
||||
auto rows = std::min(source->get_size_per_chunk(), total_rows);
|
||||
auto dataset = knowhere::GenDataSet(rows, dim, mat);
|
||||
dataset->SetIsSparse(true);
|
||||
try {
|
||||
if (chunk_id == 0) {
|
||||
index_->BuildWithDataset(dataset, conf);
|
||||
} else {
|
||||
index_->AddWithDataset(dataset, conf);
|
||||
auto dim = new_data_dim;
|
||||
auto size_per_chunk = field_raw_data->get_size_per_chunk();
|
||||
auto build_threshold = get_build_threshold();
|
||||
bool is_mapping_storage = field_raw_data->is_mapping_storage();
|
||||
auto& valid_data = field_raw_data->get_valid_data();
|
||||
|
||||
if (!built_) {
|
||||
const void* data_ptr = nullptr;
|
||||
std::vector<value_type> data_buf;
|
||||
|
||||
int64_t start_chunk = 0;
|
||||
int64_t end_chunk = (build_threshold - 1) / size_per_chunk;
|
||||
|
||||
if (start_chunk == end_chunk) {
|
||||
data_ptr = field_raw_data->get_chunk_data(start_chunk);
|
||||
} else {
|
||||
data_buf.resize(build_threshold);
|
||||
int64_t actual_copy_count = 0;
|
||||
for (int64_t chunk_id = start_chunk; chunk_id <= end_chunk;
|
||||
++chunk_id) {
|
||||
int64_t copy_start =
|
||||
std::max((int64_t)0, chunk_id * size_per_chunk);
|
||||
int64_t copy_end =
|
||||
std::min(build_threshold, (chunk_id + 1) * size_per_chunk);
|
||||
int64_t copy_count = copy_end - copy_start;
|
||||
// For mapping storage, chunk data is already compactly stored,
|
||||
// so we can copy directly from chunk
|
||||
auto chunk_data = static_cast<const value_type*>(
|
||||
field_raw_data->get_chunk_data(chunk_id));
|
||||
int64_t chunk_offset = copy_start - chunk_id * size_per_chunk;
|
||||
for (int64_t i = 0; i < copy_count; ++i) {
|
||||
data_buf[actual_copy_count + i] =
|
||||
chunk_data[chunk_offset + i];
|
||||
}
|
||||
} catch (SegcoreError& error) {
|
||||
LOG_ERROR("growing sparse index build error: {}", error.what());
|
||||
recreate_index(get_data_type(), nullptr);
|
||||
index_cur_ = 0;
|
||||
return;
|
||||
actual_copy_count += copy_count;
|
||||
}
|
||||
index_cur_.fetch_add(rows);
|
||||
total_rows -= rows;
|
||||
chunk_id++;
|
||||
data_ptr = data_buf.data();
|
||||
}
|
||||
|
||||
auto dataset = knowhere::GenDataSet(build_threshold, dim, data_ptr);
|
||||
dataset->SetIsSparse(true);
|
||||
try {
|
||||
index_->BuildWithDataset(dataset, conf);
|
||||
if (is_mapping_storage) {
|
||||
auto logical_offset =
|
||||
field_raw_data->get_logical_offset(build_threshold - 1);
|
||||
auto update_count = logical_offset + 1;
|
||||
index_->UpdateValidData(valid_data.data(), update_count);
|
||||
}
|
||||
built_ = true;
|
||||
index_cur_.fetch_add(build_threshold);
|
||||
} catch (SegcoreError& error) {
|
||||
LOG_ERROR("growing sparse index build error: {}", error.what());
|
||||
recreate_index(get_data_type(), field_raw_data);
|
||||
return;
|
||||
}
|
||||
built_ = true;
|
||||
sync_with_index_ = true;
|
||||
// if not built_, new rows in data_source have already been added to
|
||||
// source(ConcurrentVector<SparseFloatVector>) and thus added to the
|
||||
// index, thus no need to add again.
|
||||
return;
|
||||
}
|
||||
|
||||
auto dataset = knowhere::GenDataSet(size, new_data_dim, data_source);
|
||||
dataset->SetIsSparse(true);
|
||||
index_->AddWithDataset(dataset, conf);
|
||||
index_cur_.fetch_add(size);
|
||||
// Append rest data when index has been built
|
||||
int64_t add_count = 0;
|
||||
int64_t total_count = 0;
|
||||
if (valid_data.empty()) {
|
||||
// Non-nullable case: add all rows
|
||||
add_count = reserved_offset + size - index_cur_.load();
|
||||
total_count = size;
|
||||
if (add_count <= 0) {
|
||||
sync_with_index_.store(true);
|
||||
return;
|
||||
}
|
||||
auto data_ptr = source + (total_count - add_count);
|
||||
auto dataset = knowhere::GenDataSet(add_count, dim, data_ptr);
|
||||
dataset->SetIsSparse(true);
|
||||
try {
|
||||
index_->AddWithDataset(dataset, conf);
|
||||
index_cur_.fetch_add(add_count);
|
||||
sync_with_index_.store(true);
|
||||
} catch (SegcoreError& error) {
|
||||
LOG_ERROR("growing sparse index add error: {}", error.what());
|
||||
recreate_index(get_data_type(), field_raw_data);
|
||||
}
|
||||
} else {
|
||||
// Nullable case: only add valid rows (matching dense vector approach)
|
||||
auto index_total_count = index_->GetOffsetMapping().GetTotalCount();
|
||||
auto add_valid_data_count = reserved_offset + size - index_total_count;
|
||||
for (auto i = reserved_offset; i < reserved_offset + size; i++) {
|
||||
if (valid_data[i]) {
|
||||
total_count++;
|
||||
if (i >= index_total_count) {
|
||||
add_count++;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (add_count <= 0 && add_valid_data_count <= 0) {
|
||||
sync_with_index_.store(true);
|
||||
return;
|
||||
}
|
||||
if (add_count > 0) {
|
||||
auto data_ptr = source + (total_count - add_count);
|
||||
auto dataset = knowhere::GenDataSet(add_count, dim, data_ptr);
|
||||
dataset->SetIsSparse(true);
|
||||
try {
|
||||
index_->AddWithDataset(dataset, conf);
|
||||
} catch (SegcoreError& error) {
|
||||
LOG_ERROR("growing sparse index add error: {}", error.what());
|
||||
recreate_index(get_data_type(), field_raw_data);
|
||||
}
|
||||
}
|
||||
if (add_valid_data_count > 0) {
|
||||
index_->UpdateValidData(valid_data.data() + index_total_count,
|
||||
add_valid_data_count);
|
||||
}
|
||||
index_cur_.fetch_add(add_count);
|
||||
sync_with_index_.store(true);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
@ -203,8 +383,10 @@ VectorFieldIndexing::AppendSegmentIndexDense(int64_t reserved_offset,
|
||||
auto dim = get_dim();
|
||||
auto conf = get_build_params(get_data_type());
|
||||
auto size_per_chunk = field_raw_data->get_size_per_chunk();
|
||||
//append vector [vector_id_beg, vector_id_end] into index
|
||||
//build index [vector_id_beg, build_threshold) when index not exist
|
||||
auto build_threshold = get_build_threshold();
|
||||
bool is_mapping_storage = field_raw_data->is_mapping_storage();
|
||||
auto& valid_data = field_raw_data->get_valid_data();
|
||||
|
||||
AssertInfo(ConcurrentDenseVectorCheck(field_raw_data, get_data_type()),
|
||||
"vec_base can't cast to ConcurrentVector type");
|
||||
size_t vec_length;
|
||||
@ -216,88 +398,112 @@ VectorFieldIndexing::AppendSegmentIndexDense(int64_t reserved_offset,
|
||||
vec_length = dim * sizeof(bfloat16);
|
||||
}
|
||||
if (!built_) {
|
||||
idx_t vector_id_beg = index_cur_.load();
|
||||
Assert(vector_id_beg == 0);
|
||||
idx_t vector_id_end = get_build_threshold() - 1;
|
||||
auto chunk_id_beg = vector_id_beg / size_per_chunk;
|
||||
auto chunk_id_end = vector_id_end / size_per_chunk;
|
||||
const void* data_ptr;
|
||||
std::unique_ptr<char[]> data_buf;
|
||||
// Chunk data stores valid vectors compactly for both nullable and non-nullable
|
||||
int64_t start_chunk = 0;
|
||||
int64_t end_chunk = (build_threshold - 1) / size_per_chunk;
|
||||
|
||||
int64_t vec_num = vector_id_end - vector_id_beg + 1;
|
||||
// for train index
|
||||
const void* data_addr;
|
||||
unique_ptr<char[]> vec_data;
|
||||
//all train data in one chunk
|
||||
if (chunk_id_beg == chunk_id_end) {
|
||||
data_addr = field_raw_data->get_chunk_data(chunk_id_beg);
|
||||
if (start_chunk == end_chunk) {
|
||||
auto chunk_data = static_cast<const char*>(
|
||||
field_raw_data->get_chunk_data(start_chunk));
|
||||
data_ptr = chunk_data;
|
||||
} else {
|
||||
//merge data from multiple chunks together
|
||||
vec_data = std::make_unique<char[]>(vec_num * vec_length);
|
||||
int64_t offset = 0;
|
||||
//copy vector data [vector_id_beg, vector_id_end]
|
||||
for (int chunk_id = chunk_id_beg; chunk_id <= chunk_id_end;
|
||||
chunk_id++) {
|
||||
int chunk_offset = 0;
|
||||
int chunk_copysz =
|
||||
chunk_id == chunk_id_end
|
||||
? vector_id_end - chunk_id * size_per_chunk + 1
|
||||
: size_per_chunk;
|
||||
std::memcpy(
|
||||
(void*)((const char*)vec_data.get() + offset * vec_length),
|
||||
(void*)((const char*)field_raw_data->get_chunk_data(
|
||||
chunk_id) +
|
||||
chunk_offset * vec_length),
|
||||
chunk_copysz * vec_length);
|
||||
offset += chunk_copysz;
|
||||
data_buf = std::make_unique<char[]>(build_threshold * vec_length);
|
||||
int64_t actual_copy_count = 0;
|
||||
for (int64_t chunk_id = start_chunk; chunk_id <= end_chunk;
|
||||
++chunk_id) {
|
||||
auto chunk_data = static_cast<const char*>(
|
||||
field_raw_data->get_chunk_data(chunk_id));
|
||||
int64_t copy_start =
|
||||
std::max((int64_t)0, chunk_id * size_per_chunk);
|
||||
int64_t copy_end =
|
||||
std::min(build_threshold, (chunk_id + 1) * size_per_chunk);
|
||||
int64_t copy_count = copy_end - copy_start;
|
||||
auto src =
|
||||
chunk_data +
|
||||
(copy_start - chunk_id * size_per_chunk) * vec_length;
|
||||
std::memcpy(data_buf.get() + actual_copy_count * vec_length,
|
||||
src,
|
||||
copy_count * vec_length);
|
||||
actual_copy_count += copy_count;
|
||||
}
|
||||
data_addr = vec_data.get();
|
||||
data_ptr = data_buf.get();
|
||||
}
|
||||
auto dataset = knowhere::GenDataSet(vec_num, dim, data_addr);
|
||||
dataset->SetIsOwner(false);
|
||||
|
||||
auto dataset = knowhere::GenDataSet(build_threshold, dim, data_ptr);
|
||||
try {
|
||||
index_->BuildWithDataset(dataset, conf);
|
||||
if (is_mapping_storage) {
|
||||
auto logical_offset =
|
||||
field_raw_data->get_logical_offset(build_threshold - 1);
|
||||
auto update_count = logical_offset + 1;
|
||||
index_->UpdateValidData(valid_data.data(), update_count);
|
||||
}
|
||||
built_ = true;
|
||||
index_cur_.fetch_add(build_threshold);
|
||||
} catch (SegcoreError& error) {
|
||||
LOG_ERROR("growing index build error: {}", error.what());
|
||||
recreate_index(get_data_type(), field_raw_data);
|
||||
return;
|
||||
}
|
||||
index_cur_.fetch_add(vec_num);
|
||||
built_ = true;
|
||||
}
|
||||
//append rest data when index has built
|
||||
idx_t vector_id_beg = index_cur_.load();
|
||||
idx_t vector_id_end = reserved_offset + size - 1;
|
||||
auto chunk_id_beg = vector_id_beg / size_per_chunk;
|
||||
auto chunk_id_end = vector_id_end / size_per_chunk;
|
||||
int64_t vec_num = vector_id_end - vector_id_beg + 1;
|
||||
|
||||
if (vec_num <= 0) {
|
||||
sync_with_index_.store(true);
|
||||
return;
|
||||
}
|
||||
|
||||
if (sync_with_index_.load()) {
|
||||
Assert(size == vec_num);
|
||||
auto dataset = knowhere::GenDataSet(vec_num, dim, data_source);
|
||||
index_->AddWithDataset(dataset, conf);
|
||||
index_cur_.fetch_add(vec_num);
|
||||
} else {
|
||||
for (int chunk_id = chunk_id_beg; chunk_id <= chunk_id_end;
|
||||
chunk_id++) {
|
||||
int chunk_offset = chunk_id == chunk_id_beg
|
||||
? index_cur_ - chunk_id * size_per_chunk
|
||||
: 0;
|
||||
int chunk_sz =
|
||||
chunk_id == chunk_id_end
|
||||
? vector_id_end % size_per_chunk - chunk_offset + 1
|
||||
: size_per_chunk - chunk_offset;
|
||||
auto dataset = knowhere::GenDataSet(
|
||||
chunk_sz,
|
||||
dim,
|
||||
(const char*)field_raw_data->get_chunk_data(chunk_id) +
|
||||
chunk_offset * vec_length);
|
||||
index_->AddWithDataset(dataset, conf);
|
||||
index_cur_.fetch_add(chunk_sz);
|
||||
int64_t add_count = 0;
|
||||
int64_t total_count = 0;
|
||||
if (valid_data.empty()) {
|
||||
add_count = reserved_offset + size - index_cur_.load();
|
||||
total_count = size;
|
||||
if (add_count <= 0) {
|
||||
sync_with_index_.store(true);
|
||||
return;
|
||||
}
|
||||
auto data_ptr = static_cast<const char*>(data_source) +
|
||||
(total_count - add_count) * vec_length;
|
||||
auto dataset = knowhere::GenDataSet(add_count, dim, data_ptr);
|
||||
try {
|
||||
index_->AddWithDataset(dataset, conf);
|
||||
index_cur_.fetch_add(add_count);
|
||||
sync_with_index_.store(true);
|
||||
} catch (SegcoreError& error) {
|
||||
LOG_ERROR("growing index add error: {}", error.what());
|
||||
recreate_index(get_data_type(), field_raw_data);
|
||||
}
|
||||
} else {
|
||||
// Nullable dense vectors: data_source (proto) contains valid vectors compactly
|
||||
auto index_total_count = index_->GetOffsetMapping().GetTotalCount();
|
||||
auto add_valid_data_count = reserved_offset + size - index_total_count;
|
||||
auto index_cur_val = index_cur_.load();
|
||||
// Count valid vectors in this batch range
|
||||
for (auto i = reserved_offset; i < reserved_offset + size; i++) {
|
||||
if (valid_data[i]) {
|
||||
total_count++;
|
||||
if (i >= index_total_count) {
|
||||
add_count++;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (add_count <= 0 && add_valid_data_count <= 0) {
|
||||
sync_with_index_.store(true);
|
||||
return;
|
||||
}
|
||||
if (add_count > 0) {
|
||||
// data_source contains valid vectors compactly, skip already indexed ones
|
||||
auto data_ptr = static_cast<const char*>(data_source) +
|
||||
(total_count - add_count) * vec_length;
|
||||
auto dataset = knowhere::GenDataSet(add_count, dim, data_ptr);
|
||||
try {
|
||||
index_->AddWithDataset(dataset, conf);
|
||||
} catch (SegcoreError& error) {
|
||||
LOG_ERROR("growing index add error: {}", error.what());
|
||||
recreate_index(get_data_type(), field_raw_data);
|
||||
}
|
||||
}
|
||||
if (add_valid_data_count > 0) {
|
||||
index_->UpdateValidData(valid_data.data() + index_total_count,
|
||||
add_valid_data_count);
|
||||
}
|
||||
index_cur_.fetch_add(add_count);
|
||||
sync_with_index_.store(true);
|
||||
}
|
||||
}
|
||||
|
||||
@ -434,93 +434,19 @@ class IndexingRecord {
|
||||
assert(offset_id == schema_.size());
|
||||
}
|
||||
|
||||
// concurrent, reentrant
|
||||
void
|
||||
AppendingIndex(int64_t reserved_offset,
|
||||
int64_t size,
|
||||
FieldId fieldId,
|
||||
const DataArray* stream_data,
|
||||
const InsertRecord<false>& record) {
|
||||
if (!is_in(fieldId)) {
|
||||
return;
|
||||
}
|
||||
auto& indexing = field_indexings_.at(fieldId);
|
||||
auto type = indexing->get_data_type();
|
||||
auto field_raw_data = record.get_data_base(fieldId);
|
||||
if (type == DataType::VECTOR_FLOAT &&
|
||||
reserved_offset + size >= indexing->get_build_threshold()) {
|
||||
indexing->AppendSegmentIndexDense(
|
||||
reserved_offset,
|
||||
size,
|
||||
field_raw_data,
|
||||
stream_data->vectors().float_vector().data().data());
|
||||
} else if (type == DataType::VECTOR_FLOAT16 &&
|
||||
reserved_offset + size >= indexing->get_build_threshold()) {
|
||||
indexing->AppendSegmentIndexDense(
|
||||
reserved_offset,
|
||||
size,
|
||||
field_raw_data,
|
||||
stream_data->vectors().float16_vector().data());
|
||||
} else if (type == DataType::VECTOR_BFLOAT16 &&
|
||||
reserved_offset + size >= indexing->get_build_threshold()) {
|
||||
indexing->AppendSegmentIndexDense(
|
||||
reserved_offset,
|
||||
size,
|
||||
field_raw_data,
|
||||
stream_data->vectors().bfloat16_vector().data());
|
||||
} else if (type == DataType::VECTOR_SPARSE_U32_F32) {
|
||||
auto data = SparseBytesToRows(
|
||||
stream_data->vectors().sparse_float_vector().contents());
|
||||
indexing->AppendSegmentIndexSparse(
|
||||
reserved_offset,
|
||||
size,
|
||||
stream_data->vectors().sparse_float_vector().dim(),
|
||||
field_raw_data,
|
||||
data.get());
|
||||
} else if (type == DataType::GEOMETRY) {
|
||||
// For geometry fields, append data incrementally to RTree index
|
||||
indexing->AppendSegmentIndex(
|
||||
reserved_offset, size, field_raw_data, stream_data);
|
||||
}
|
||||
}
|
||||
const InsertRecord<false>& record);
|
||||
|
||||
// concurrent, reentrant
|
||||
void
|
||||
AppendingIndex(int64_t reserved_offset,
|
||||
int64_t size,
|
||||
FieldId fieldId,
|
||||
const FieldDataPtr data,
|
||||
const InsertRecord<false>& record) {
|
||||
if (!is_in(fieldId)) {
|
||||
return;
|
||||
}
|
||||
auto& indexing = field_indexings_.at(fieldId);
|
||||
auto type = indexing->get_data_type();
|
||||
const void* p = data->Data();
|
||||
|
||||
if ((type == DataType::VECTOR_FLOAT ||
|
||||
type == DataType::VECTOR_FLOAT16 ||
|
||||
type == DataType::VECTOR_BFLOAT16) &&
|
||||
reserved_offset + size >= indexing->get_build_threshold()) {
|
||||
auto vec_base = record.get_data_base(fieldId);
|
||||
indexing->AppendSegmentIndexDense(
|
||||
reserved_offset, size, vec_base, data->Data());
|
||||
} else if (type == DataType::VECTOR_SPARSE_U32_F32) {
|
||||
auto vec_base = record.get_data_base(fieldId);
|
||||
indexing->AppendSegmentIndexSparse(
|
||||
reserved_offset,
|
||||
size,
|
||||
std::dynamic_pointer_cast<const FieldData<SparseFloatVector>>(
|
||||
data)
|
||||
->Dim(),
|
||||
vec_base,
|
||||
p);
|
||||
} else if (type == DataType::GEOMETRY) {
|
||||
// For geometry fields, append data incrementally to RTree index
|
||||
auto vec_base = record.get_data_base(fieldId);
|
||||
indexing->AppendSegmentIndex(reserved_offset, size, vec_base, data);
|
||||
}
|
||||
}
|
||||
const InsertRecord<false>& record);
|
||||
|
||||
// for sparse float vector:
|
||||
// * element_size is not used
|
||||
|
||||
@ -87,12 +87,6 @@ VecIndexConfig::VecIndexConfig(const int64_t max_index_row_cout,
|
||||
|
||||
int64_t
|
||||
VecIndexConfig::GetBuildThreshold() const noexcept {
|
||||
// For sparse, do not impose a threshold and start using index with any
|
||||
// number of rows. Unlike dense vector index, growing sparse vector index
|
||||
// does not require a minimum number of rows to train.
|
||||
if (is_sparse_) {
|
||||
return 0;
|
||||
}
|
||||
auto ratio = config_.get_build_ratio();
|
||||
assert(ratio >= 0.0 && ratio < 1.0);
|
||||
return std::max(int64_t(max_index_row_count_ * ratio),
|
||||
|
||||
@ -1025,7 +1025,6 @@ class InsertRecordGrowing {
|
||||
}
|
||||
|
||||
// append a column of vector type
|
||||
// vector not support nullable, not pass valid data ptr
|
||||
template <typename VectorType>
|
||||
void
|
||||
append_data(FieldId field_id,
|
||||
@ -1033,9 +1032,15 @@ class InsertRecordGrowing {
|
||||
int64_t size_per_chunk,
|
||||
const storage::MmapChunkDescriptorPtr mmap_descriptor) {
|
||||
static_assert(std::is_base_of_v<VectorTrait, VectorType>);
|
||||
data_.emplace(field_id,
|
||||
std::make_unique<ConcurrentVector<VectorType>>(
|
||||
dim, size_per_chunk, mmap_descriptor));
|
||||
bool use_mapping_storage = is_valid_data_exist(field_id);
|
||||
data_.emplace(
|
||||
field_id,
|
||||
std::make_unique<ConcurrentVector<VectorType>>(
|
||||
dim,
|
||||
size_per_chunk,
|
||||
mmap_descriptor,
|
||||
use_mapping_storage ? get_valid_data(field_id) : nullptr,
|
||||
use_mapping_storage));
|
||||
}
|
||||
|
||||
// append a column of scalar or sparse float vector type
|
||||
@ -1045,13 +1050,23 @@ class InsertRecordGrowing {
|
||||
int64_t size_per_chunk,
|
||||
const storage::MmapChunkDescriptorPtr mmap_descriptor) {
|
||||
static_assert(IsScalar<Type> || IsSparse<Type>);
|
||||
data_.emplace(
|
||||
field_id,
|
||||
std::make_unique<ConcurrentVector<Type>>(
|
||||
size_per_chunk,
|
||||
mmap_descriptor,
|
||||
is_valid_data_exist(field_id) ? get_valid_data(field_id)
|
||||
: nullptr));
|
||||
bool use_mapping_storage = is_valid_data_exist(field_id);
|
||||
if constexpr (IsSparse<Type>) {
|
||||
data_.emplace(
|
||||
field_id,
|
||||
std::make_unique<ConcurrentVector<Type>>(
|
||||
size_per_chunk,
|
||||
mmap_descriptor,
|
||||
use_mapping_storage ? get_valid_data(field_id) : nullptr,
|
||||
use_mapping_storage));
|
||||
} else {
|
||||
data_.emplace(
|
||||
field_id,
|
||||
std::make_unique<ConcurrentVector<Type>>(
|
||||
size_per_chunk,
|
||||
mmap_descriptor,
|
||||
use_mapping_storage ? get_valid_data(field_id) : nullptr));
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
|
||||
@ -312,13 +312,13 @@ SegmentGrowingImpl::Insert(int64_t reserved_offset,
|
||||
AssertInfo(field_id_to_offset.count(field_id),
|
||||
fmt::format("can't find field {}", field_id.get()));
|
||||
auto data_offset = field_id_to_offset[field_id];
|
||||
if (field_meta.is_nullable()) {
|
||||
insert_record_.get_valid_data(field_id)->set_data_raw(
|
||||
num_rows,
|
||||
&insert_record_proto->fields_data(data_offset),
|
||||
field_meta);
|
||||
}
|
||||
if (!indexing_record_.HasRawData(field_id)) {
|
||||
if (field_meta.is_nullable()) {
|
||||
insert_record_.get_valid_data(field_id)->set_data_raw(
|
||||
num_rows,
|
||||
&insert_record_proto->fields_data(data_offset),
|
||||
field_meta);
|
||||
}
|
||||
insert_record_.get_data_base(field_id)->set_data_raw(
|
||||
reserved_offset,
|
||||
num_rows,
|
||||
@ -937,14 +937,31 @@ SegmentGrowingImpl::bulk_subscript(milvus::OpContext* op_ctx,
|
||||
auto& field_meta = schema_->operator[](field_id);
|
||||
auto vec_ptr = insert_record_.get_data_base(field_id);
|
||||
if (field_meta.is_vector()) {
|
||||
auto result = CreateEmptyVectorDataArray(count, field_meta);
|
||||
int64_t valid_count = count;
|
||||
const bool* valid_data = nullptr;
|
||||
const int64_t* valid_offsets = seg_offsets;
|
||||
ValidResult filter_result;
|
||||
|
||||
if (field_meta.is_nullable()) {
|
||||
filter_result =
|
||||
FilterVectorValidOffsets(op_ctx, field_id, seg_offsets, count);
|
||||
valid_count = filter_result.valid_count;
|
||||
valid_data = filter_result.valid_data.get();
|
||||
valid_offsets = filter_result.valid_offsets.data();
|
||||
}
|
||||
|
||||
auto result = CreateEmptyVectorDataArray(
|
||||
count, valid_count, valid_data, field_meta);
|
||||
if (valid_count == 0) {
|
||||
return result;
|
||||
}
|
||||
if (field_meta.get_data_type() == DataType::VECTOR_FLOAT) {
|
||||
bulk_subscript_impl<FloatVector>(op_ctx,
|
||||
field_id,
|
||||
field_meta.get_sizeof(),
|
||||
vec_ptr,
|
||||
seg_offsets,
|
||||
count,
|
||||
valid_offsets,
|
||||
valid_count,
|
||||
result->mutable_vectors()
|
||||
->mutable_float_vector()
|
||||
->mutable_data()
|
||||
@ -955,8 +972,8 @@ SegmentGrowingImpl::bulk_subscript(milvus::OpContext* op_ctx,
|
||||
field_id,
|
||||
field_meta.get_sizeof(),
|
||||
vec_ptr,
|
||||
seg_offsets,
|
||||
count,
|
||||
valid_offsets,
|
||||
valid_count,
|
||||
result->mutable_vectors()->mutable_binary_vector()->data());
|
||||
} else if (field_meta.get_data_type() == DataType::VECTOR_FLOAT16) {
|
||||
bulk_subscript_impl<Float16Vector>(
|
||||
@ -964,8 +981,8 @@ SegmentGrowingImpl::bulk_subscript(milvus::OpContext* op_ctx,
|
||||
field_id,
|
||||
field_meta.get_sizeof(),
|
||||
vec_ptr,
|
||||
seg_offsets,
|
||||
count,
|
||||
valid_offsets,
|
||||
valid_count,
|
||||
result->mutable_vectors()->mutable_float16_vector()->data());
|
||||
} else if (field_meta.get_data_type() == DataType::VECTOR_BFLOAT16) {
|
||||
bulk_subscript_impl<BFloat16Vector>(
|
||||
@ -973,8 +990,8 @@ SegmentGrowingImpl::bulk_subscript(milvus::OpContext* op_ctx,
|
||||
field_id,
|
||||
field_meta.get_sizeof(),
|
||||
vec_ptr,
|
||||
seg_offsets,
|
||||
count,
|
||||
valid_offsets,
|
||||
valid_count,
|
||||
result->mutable_vectors()->mutable_bfloat16_vector()->data());
|
||||
} else if (field_meta.get_data_type() ==
|
||||
DataType::VECTOR_SPARSE_U32_F32) {
|
||||
@ -982,8 +999,8 @@ SegmentGrowingImpl::bulk_subscript(milvus::OpContext* op_ctx,
|
||||
op_ctx,
|
||||
field_id,
|
||||
(const ConcurrentVector<SparseFloatVector>*)vec_ptr,
|
||||
seg_offsets,
|
||||
count,
|
||||
valid_offsets,
|
||||
valid_count,
|
||||
result->mutable_vectors()->mutable_sparse_float_vector());
|
||||
result->mutable_vectors()->set_dim(
|
||||
result->vectors().sparse_float_vector().dim());
|
||||
@ -993,8 +1010,8 @@ SegmentGrowingImpl::bulk_subscript(milvus::OpContext* op_ctx,
|
||||
field_id,
|
||||
field_meta.get_sizeof(),
|
||||
vec_ptr,
|
||||
seg_offsets,
|
||||
count,
|
||||
valid_offsets,
|
||||
valid_count,
|
||||
result->mutable_vectors()->mutable_int8_vector()->data());
|
||||
} else if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) {
|
||||
bulk_subscript_vector_array_impl(op_ctx,
|
||||
@ -1190,7 +1207,7 @@ SegmentGrowingImpl::bulk_subscript_sparse_float_vector_impl(
|
||||
[&](size_t i) {
|
||||
auto offset = seg_offsets[i];
|
||||
return offset != INVALID_SEG_OFFSET
|
||||
? vec_raw->get_element(offset)
|
||||
? vec_raw->get_physical_element(offset)
|
||||
: nullptr;
|
||||
},
|
||||
count,
|
||||
@ -1257,12 +1274,8 @@ SegmentGrowingImpl::bulk_subscript_impl(milvus::OpContext* op_ctx,
|
||||
for (int i = 0; i < count; ++i) {
|
||||
auto dst = output_base + i * element_sizeof;
|
||||
auto offset = seg_offsets[i];
|
||||
if (offset == INVALID_SEG_OFFSET) {
|
||||
memset(dst, 0, element_sizeof);
|
||||
} else {
|
||||
auto src = (const uint8_t*)vec.get_element(offset);
|
||||
memcpy(dst, src, element_sizeof);
|
||||
}
|
||||
auto src = (const uint8_t*)vec.get_physical_element(offset);
|
||||
memcpy(dst, src, element_sizeof);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@ -1860,4 +1873,68 @@ SegmentGrowingImpl::BuildGeometryCacheForLoad(
|
||||
}
|
||||
}
|
||||
|
||||
// Splits the requested offsets of a nullable vector field into per-row
// validity flags and the physical offsets of the rows that carry vector data.
//
// Parameters:
//   op_ctx      - operation context; not used by this function.
//   field_id    - the vector field being read.
//   seg_offsets - `count` segment (logical) offsets requested by the caller.
//   count       - number of entries in `seg_offsets`.
//
// Returns a ValidResult where:
//   - valid_data has `count` flags, true only for rows whose physical offset
//     was recorded in valid_offsets (so the number of true flags always equals
//     valid_count);
//   - valid_offsets holds those physical offsets in request order;
//   - valid_count == valid_offsets.size().
// When neither the index nor the insert record exposes validity information,
// valid_data stays null and valid_offsets is a plain copy of seg_offsets, so
// callers that read `valid_count` entries from `valid_offsets.data()` never
// walk past the end of an empty vector.
SegmentGrowingImpl::ValidResult
SegmentGrowingImpl::FilterVectorValidOffsets(milvus::OpContext* op_ctx,
                                             FieldId field_id,
                                             const int64_t* seg_offsets,
                                             int64_t count) const {
    // Sentinel for "this row has no stored vector".
    constexpr int64_t kNoPhysicalOffset = -1;

    ValidResult result;
    bool filtered = false;

    // Allocates the outputs once a validity source has been found.
    // make_unique<bool[]> value-initializes, so every flag starts as false.
    auto begin_filter = [&]() {
        result.valid_data = std::make_unique<bool[]>(count);
        result.valid_offsets.reserve(count);
        filtered = true;
    };

    // Records row `i`. The flag is set only when an offset is stored, which
    // keeps valid_data and valid_offsets in lockstep even if the validity
    // source and the offset mapping disagree about a row.
    auto record = [&](int64_t i, int64_t physical_offset) {
        const bool has_data = physical_offset >= 0;
        result.valid_data[i] = has_data;
        if (has_data) {
            result.valid_offsets.push_back(physical_offset);
        }
    };

    if (indexing_record_.SyncDataWithIndex(field_id)) {
        // Data is served from the growing index: ask the index about validity.
        const auto& field_indexing =
            indexing_record_.get_vec_field_indexing(field_id);
        auto indexing = field_indexing.get_segment_indexing();
        auto vec_index = dynamic_cast<index::VectorIndex*>(indexing.get());

        if (vec_index != nullptr && vec_index->HasValidData()) {
            begin_filter();
            for (int64_t i = 0; i < count; ++i) {
                const auto offset = seg_offsets[i];
                record(i,
                       vec_index->IsRowValid(offset)
                           ? vec_index->GetPhysicalOffset(offset)
                           : kNoPhysicalOffset);
            }
        }
    } else {
        // Data is served from the insert record's raw column.
        auto vec_base = insert_record_.get_data_base(field_id);
        if (vec_base != nullptr) {
            const auto& valid_data_vec = vec_base->get_valid_data();
            const bool is_mapping_storage = vec_base->is_mapping_storage();
            if (!valid_data_vec.empty()) {
                begin_filter();
                const auto num_rows =
                    static_cast<int64_t>(valid_data_vec.size());
                for (int64_t i = 0; i < count; ++i) {
                    const auto offset = seg_offsets[i];
                    int64_t physical_offset = kNoPhysicalOffset;
                    // Out-of-range offsets (including negative placeholders)
                    // are treated as null rows.
                    if (offset >= 0 && offset < num_rows &&
                        valid_data_vec[offset]) {
                        // NOTE(review): without mapping storage the column is
                        // assumed to be dense, i.e. physical offset == logical
                        // offset. Confirm against ConcurrentVector.
                        physical_offset =
                            is_mapping_storage
                                ? vec_base->get_physical_offset(offset)
                                : offset;
                    }
                    record(i, physical_offset);
                }
            }
        }
    }

    if (filtered) {
        result.valid_count =
            static_cast<int64_t>(result.valid_offsets.size());
    } else {
        // No validity information available: hand back the requested offsets
        // unchanged so that valid_offsets really has valid_count entries.
        result.valid_offsets.assign(seg_offsets, seg_offsets + count);
        result.valid_count = count;
    }
    return result;
}
|
||||
|
||||
} // namespace milvus::segcore
|
||||
|
||||
@ -504,6 +504,17 @@ class SegmentGrowingImpl : public SegmentGrowing {
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
// Result of filtering a batch of requested rows of a nullable vector field
// (produced by FilterVectorValidOffsets).
struct ValidResult {
|
||||
// Number of rows for which vector data should be fetched.
int64_t valid_count = 0;
|
||||
// One flag per requested row; stays null when the filter found no validity
// information for the field.
std::unique_ptr<bool[]> valid_data;
|
||||
// Physical offsets collected for the valid rows, in request order.
std::vector<int64_t> valid_offsets;
|
||||
};
|
||||
|
||||
// For a nullable vector field, inspects the `count` requested segment offsets
// in `seg_offsets` and returns the per-row validity flags together with the
// physical offsets of the valid rows (see ValidResult).
ValidResult
|
||||
FilterVectorValidOffsets(milvus::OpContext* op_ctx,
|
||||
FieldId field_id,
|
||||
const int64_t* seg_offsets,
|
||||
int64_t count) const;
|
||||
|
||||
protected:
|
||||
int64_t
|
||||
|
||||
@ -280,11 +280,10 @@ TEST_P(GrowingIndexTest, Correctness) {
|
||||
auto inserted = (i + 1) * per_batch;
|
||||
// once index built, chunk data will be removed.
|
||||
// growing index will only be built when num rows reached
|
||||
// get_build_threshold(). This value for sparse is 0, thus sparse index
|
||||
// will be built since the first chunk. Dense segment buffers the first
|
||||
// get_build_threshold(). Both sparse and dense segment buffer the first
|
||||
// 2 chunks before building an index in this test case.
|
||||
|
||||
if ((!is_sparse && i < 2) || !intermin_index_with_raw_data) {
|
||||
if (i < 2 || !intermin_index_with_raw_data) {
|
||||
EXPECT_EQ(field_data->num_chunk(),
|
||||
upper_div(inserted, field_data->get_size_per_chunk()));
|
||||
} else {
|
||||
|
||||
@ -12,12 +12,18 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "common/Types.h"
|
||||
#include "common/IndexMeta.h"
|
||||
#include "knowhere/comp/index_param.h"
|
||||
#include "segcore/SegmentGrowing.h"
|
||||
#include "segcore/SegmentGrowingImpl.h"
|
||||
#include "pb/schema.pb.h"
|
||||
#include "pb/plan.pb.h"
|
||||
#include "query/Plan.h"
|
||||
#include "expr/ITypeExpr.h"
|
||||
#include "plan/PlanNode.h"
|
||||
#include "test_utils/DataGen.h"
|
||||
#include "test_utils/storage_test_utils.h"
|
||||
#include "test_utils/GenExprProto.h"
|
||||
|
||||
using namespace milvus::segcore;
|
||||
using namespace milvus;
|
||||
@ -435,6 +441,325 @@ TEST(Growing, FillNullableData) {
|
||||
}
|
||||
}
|
||||
|
||||
// Parameterized fixture for nullable-vector tests on growing segments.
// Each parameter tuple is unpacked into the public members below in SetUp().
class GrowingNullableTest
    : public ::testing::TestWithParam<
          std::tuple</*data_type*/ DataType,
                     /*metric_type*/ knowhere::MetricType,
                     /*index_type*/ std::string,
                     /*null_percent*/ int,
                     /*enable_interim_index*/ bool>> {
 public:
    // Copies every element of the current test parameter into its member.
    void
    SetUp() override {
        const auto& param = GetParam();
        data_type = std::get<0>(param);
        metric_type = std::get<1>(param);
        index_type = std::get<2>(param);
        null_percent = std::get<3>(param);
        enable_interim_index = std::get<4>(param);
    }

    DataType data_type;             // vector type of the field under test
    knowhere::MetricType metric_type;
    std::string index_type;
    int null_percent;               // percentage of rows generated as null
    bool enable_interim_index;      // toggles the interim segment index
};
|
||||
|
||||
static std::vector<
|
||||
std::tuple<DataType, knowhere::MetricType, std::string, int, bool>>
|
||||
GenerateGrowingNullableTestParams() {
|
||||
std::vector<
|
||||
std::tuple<DataType, knowhere::MetricType, std::string, int, bool>>
|
||||
params;
|
||||
|
||||
// Dense float vectors with IVF_FLAT
|
||||
std::vector<std::tuple<DataType, knowhere::MetricType, std::string>>
|
||||
base_configs = {
|
||||
{DataType::VECTOR_FLOAT,
|
||||
knowhere::metric::L2,
|
||||
knowhere::IndexEnum::INDEX_FAISS_IVFFLAT},
|
||||
{DataType::VECTOR_FLOAT,
|
||||
knowhere::metric::IP,
|
||||
knowhere::IndexEnum::INDEX_FAISS_IVFFLAT},
|
||||
{DataType::VECTOR_FLOAT,
|
||||
knowhere::metric::COSINE,
|
||||
knowhere::IndexEnum::INDEX_FAISS_IVFFLAT},
|
||||
{DataType::VECTOR_SPARSE_U32_F32,
|
||||
knowhere::metric::IP,
|
||||
knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX},
|
||||
};
|
||||
|
||||
std::vector<int> null_percents = {0, 20, 100};
|
||||
|
||||
std::vector<bool> interim_index_configs = {true, false};
|
||||
|
||||
for (const auto& [dtype, metric, idx_type] : base_configs) {
|
||||
for (int null_pct : null_percents) {
|
||||
for (bool enable_interim : interim_index_configs) {
|
||||
params.push_back(
|
||||
{dtype, metric, idx_type, null_pct, enable_interim});
|
||||
}
|
||||
}
|
||||
}
|
||||
return params;
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
NullableVectorParameters,
|
||||
GrowingNullableTest,
|
||||
::testing::ValuesIn(GenerateGrowingNullableTestParams()));
|
||||
|
||||
TEST_P(GrowingNullableTest, SearchAndQueryNullableVectors) {
|
||||
using namespace milvus::query;
|
||||
|
||||
bool nullable = true;
|
||||
|
||||
auto schema = std::make_shared<Schema>();
|
||||
auto int64_field = schema->AddDebugField("int64", DataType::INT64);
|
||||
int64_t dim = 8;
|
||||
auto vec = schema->AddDebugField(
|
||||
"embeddings", data_type, dim, metric_type, nullable);
|
||||
schema->set_primary_field_id(int64_field);
|
||||
|
||||
std::map<std::string, std::string> index_params;
|
||||
std::map<std::string, std::string> type_params;
|
||||
if (data_type == DataType::VECTOR_SPARSE_U32_F32) {
|
||||
index_params = {{"index_type", index_type},
|
||||
{"metric_type", metric_type}};
|
||||
type_params = {};
|
||||
} else {
|
||||
index_params = {{"index_type", index_type},
|
||||
{"metric_type", metric_type},
|
||||
{"nlist", "128"}};
|
||||
type_params = {{"dim", std::to_string(dim)}};
|
||||
}
|
||||
FieldIndexMeta fieldIndexMeta(
|
||||
vec, std::move(index_params), std::move(type_params));
|
||||
auto config = SegcoreConfig::default_config();
|
||||
config.set_chunk_rows(1024);
|
||||
config.set_enable_interim_segment_index(enable_interim_index);
|
||||
// Explicitly set interim index type to avoid contamination from other tests
|
||||
config.set_dense_vector_intermin_index_type(
|
||||
knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC);
|
||||
std::map<FieldId, FieldIndexMeta> filedMap = {{vec, fieldIndexMeta}};
|
||||
IndexMetaPtr metaPtr =
|
||||
std::make_shared<CollectionIndexMeta>(100000, std::move(filedMap));
|
||||
auto segment_growing = CreateGrowingSegment(schema, metaPtr, 1, config);
|
||||
auto segment = dynamic_cast<SegmentGrowingImpl*>(segment_growing.get());
|
||||
|
||||
int64_t batch_size = 2000;
|
||||
int64_t num_rounds = 10;
|
||||
int64_t topk = 5;
|
||||
int64_t num_queries = 2;
|
||||
Timestamp timestamp = 10000000;
|
||||
|
||||
// Prepare search plan
|
||||
std::string search_params_fmt;
|
||||
if (data_type == DataType::VECTOR_SPARSE_U32_F32) {
|
||||
search_params_fmt = R"(
|
||||
vector_anns:<
|
||||
field_id: {}
|
||||
query_info:<
|
||||
topk: {}
|
||||
round_decimal: 3
|
||||
metric_type: "{}"
|
||||
search_params: "{{\"drop_ratio_search\": 0.1}}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>
|
||||
)";
|
||||
} else {
|
||||
search_params_fmt = R"(
|
||||
vector_anns:<
|
||||
field_id: {}
|
||||
query_info:<
|
||||
topk: {}
|
||||
round_decimal: 3
|
||||
metric_type: "{}"
|
||||
search_params: "{{\"nprobe\": 10}}"
|
||||
>
|
||||
placeholder_tag: "$0"
|
||||
>
|
||||
)";
|
||||
}
|
||||
|
||||
auto raw_plan =
|
||||
fmt::format(search_params_fmt, vec.get(), topk, metric_type);
|
||||
auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str());
|
||||
auto plan =
|
||||
CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size());
|
||||
|
||||
// Create query vectors
|
||||
proto::common::PlaceholderGroup ph_group_raw;
|
||||
if (data_type == DataType::VECTOR_SPARSE_U32_F32) {
|
||||
ph_group_raw = CreateSparseFloatPlaceholderGroup(num_queries, 42);
|
||||
} else {
|
||||
auto query_data = generate_float_vector(num_queries, dim);
|
||||
ph_group_raw =
|
||||
CreatePlaceholderGroupFromBlob(num_queries, dim, query_data.data());
|
||||
}
|
||||
|
||||
auto ph_group =
|
||||
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
||||
|
||||
// Store all inserted data for verification
|
||||
// For nullable vectors, data is stored sparsely (only valid vectors)
|
||||
// We need a mapping from logical offset to physical offset
|
||||
std::vector<float> all_float_vectors; // Physical storage (only valid)
|
||||
std::vector<knowhere::sparse::SparseRow<float>> all_sparse_vectors;
|
||||
std::vector<bool> all_valid_data; // Logical storage (all rows)
|
||||
std::vector<int64_t>
|
||||
logical_to_physical; // Maps logical offset to physical
|
||||
|
||||
// Insert data in multiple rounds and test after each round
|
||||
for (int64_t round = 0; round < num_rounds; round++) {
|
||||
int64_t total_rows = (round + 1) * batch_size;
|
||||
int64_t expected_valid_count =
|
||||
total_rows - (total_rows * null_percent / 100);
|
||||
|
||||
auto dataset = DataGen(schema,
|
||||
batch_size,
|
||||
42 + round,
|
||||
0,
|
||||
1,
|
||||
10,
|
||||
1,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
null_percent);
|
||||
|
||||
// Build logical to physical mapping for this batch
|
||||
int64_t base_physical = all_float_vectors.size() / dim;
|
||||
if (data_type == DataType::VECTOR_SPARSE_U32_F32) {
|
||||
base_physical = all_sparse_vectors.size();
|
||||
}
|
||||
|
||||
auto valid_data_from_dataset = dataset.get_col_valid(vec);
|
||||
int64_t physical_idx = base_physical;
|
||||
for (size_t i = 0; i < valid_data_from_dataset.size(); i++) {
|
||||
if (valid_data_from_dataset[i]) {
|
||||
logical_to_physical.push_back(physical_idx);
|
||||
physical_idx++;
|
||||
} else {
|
||||
logical_to_physical.push_back(-1); // null
|
||||
}
|
||||
}
|
||||
|
||||
// Get original data directly from proto (sparse storage for nullable)
|
||||
// Data is stored sparsely - only valid vectors are in the proto
|
||||
if (data_type == DataType::VECTOR_FLOAT) {
|
||||
auto field_data = dataset.get_col(vec);
|
||||
auto& float_data = field_data->vectors().float_vector().data();
|
||||
all_float_vectors.insert(
|
||||
all_float_vectors.end(), float_data.begin(), float_data.end());
|
||||
} else if (data_type == DataType::VECTOR_SPARSE_U32_F32) {
|
||||
auto field_data = dataset.get_col(vec);
|
||||
auto& sparse_array = field_data->vectors().sparse_float_vector();
|
||||
for (int i = 0; i < sparse_array.contents_size(); i++) {
|
||||
auto& content = sparse_array.contents(i);
|
||||
auto row = CopyAndWrapSparseRow(content.data(), content.size());
|
||||
all_sparse_vectors.push_back(std::move(row));
|
||||
}
|
||||
}
|
||||
all_valid_data.insert(all_valid_data.end(),
|
||||
valid_data_from_dataset.begin(),
|
||||
valid_data_from_dataset.end());
|
||||
|
||||
auto offset = segment->PreInsert(batch_size);
|
||||
segment->Insert(offset,
|
||||
batch_size,
|
||||
dataset.row_ids_.data(),
|
||||
dataset.timestamps_.data(),
|
||||
dataset.raw_);
|
||||
|
||||
auto& insert_record = segment->get_insert_record();
|
||||
ASSERT_TRUE(insert_record.is_valid_data_exist(vec));
|
||||
|
||||
auto valid_data_ptr = insert_record.get_data_base(vec);
|
||||
const auto& valid_data = valid_data_ptr->get_valid_data();
|
||||
|
||||
// Test search
|
||||
auto sr =
|
||||
segment_growing->Search(plan.get(), ph_group.get(), timestamp);
|
||||
|
||||
ASSERT_EQ(sr->total_nq_, num_queries);
|
||||
ASSERT_EQ(sr->unity_topK_, topk);
|
||||
|
||||
if (expected_valid_count == 0) {
|
||||
auto total_results = sr->get_total_result_count();
|
||||
EXPECT_EQ(total_results, 0)
|
||||
<< "Round " << round
|
||||
<< ": 100% null should return 0 results, but got "
|
||||
<< total_results;
|
||||
} else {
|
||||
// Verify search results don't contain null vectors
|
||||
for (size_t i = 0; i < sr->seg_offsets_.size(); i++) {
|
||||
auto seg_offset = sr->seg_offsets_[i];
|
||||
if (seg_offset < 0) {
|
||||
continue;
|
||||
}
|
||||
ASSERT_TRUE(valid_data[seg_offset])
|
||||
<< "Round " << round
|
||||
<< ": Search returned null vector at offset " << seg_offset;
|
||||
}
|
||||
}
|
||||
|
||||
auto vec_result = segment->bulk_subscript(
|
||||
nullptr, vec, sr->seg_offsets_.data(), sr->seg_offsets_.size());
|
||||
ASSERT_TRUE(vec_result != nullptr);
|
||||
|
||||
if (data_type == DataType::VECTOR_FLOAT) {
|
||||
auto& float_data = vec_result->vectors().float_vector();
|
||||
size_t valid_idx = 0;
|
||||
for (size_t i = 0; i < sr->seg_offsets_.size(); i++) {
|
||||
auto offset = sr->seg_offsets_[i];
|
||||
if (offset < 0) {
|
||||
continue; // Skip invalid offsets
|
||||
}
|
||||
auto physical_idx = logical_to_physical[offset];
|
||||
for (int d = 0; d < dim; d++) {
|
||||
float expected_val =
|
||||
all_float_vectors[physical_idx * dim + d];
|
||||
float actual_val = float_data.data(valid_idx * dim + d);
|
||||
ASSERT_FLOAT_EQ(expected_val, actual_val)
|
||||
<< "Round " << round << ": Mismatch at logical offset "
|
||||
<< offset << " dim " << d;
|
||||
}
|
||||
valid_idx++;
|
||||
}
|
||||
} else if (data_type == DataType::VECTOR_SPARSE_U32_F32) {
|
||||
auto& sparse_data = vec_result->vectors().sparse_float_vector();
|
||||
size_t valid_idx = 0;
|
||||
for (size_t i = 0; i < sr->seg_offsets_.size(); i++) {
|
||||
auto offset = sr->seg_offsets_[i];
|
||||
if (offset < 0) {
|
||||
continue; // Skip invalid offsets
|
||||
}
|
||||
auto physical_idx = logical_to_physical[offset];
|
||||
auto& content = sparse_data.contents(valid_idx);
|
||||
auto retrieved_row =
|
||||
CopyAndWrapSparseRow(content.data(), content.size());
|
||||
const auto& expected_row = all_sparse_vectors[physical_idx];
|
||||
ASSERT_EQ(retrieved_row.size(), expected_row.size())
|
||||
<< "Round " << round
|
||||
<< ": Sparse vector size mismatch at logical offset "
|
||||
<< offset;
|
||||
for (size_t j = 0; j < retrieved_row.size(); j++) {
|
||||
ASSERT_EQ(retrieved_row[j].id, expected_row[j].id)
|
||||
<< "Round " << round
|
||||
<< ": Sparse vector id mismatch at logical offset "
|
||||
<< offset << " element " << j;
|
||||
ASSERT_FLOAT_EQ(retrieved_row[j].val, expected_row[j].val)
|
||||
<< "Round " << round
|
||||
<< ": Sparse vector val mismatch at logical offset "
|
||||
<< offset << " element " << j;
|
||||
}
|
||||
valid_idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(GrowingTest, FillVectorArrayData) {
|
||||
auto schema = std::make_shared<Schema>();
|
||||
auto int64_field = schema->AddDebugField("int64", DataType::INT64);
|
||||
|
||||
@ -500,10 +500,18 @@ SegmentInternalInterface::bulk_subscript_not_exist_field(
|
||||
const milvus::FieldMeta& field_meta, int64_t count) const {
|
||||
auto data_type = field_meta.get_data_type();
|
||||
if (IsVectorDataType(data_type)) {
|
||||
ThrowInfo(DataTypeInvalid,
|
||||
fmt::format("unsupported added field type {}",
|
||||
field_meta.get_data_type()));
|
||||
AssertInfo(field_meta.is_nullable(),
|
||||
"Non-nullable vector field should not reach here");
|
||||
|
||||
auto result = CreateEmptyVectorDataArray(0, field_meta);
|
||||
|
||||
auto valid_data = result->mutable_valid_data();
|
||||
for (int64_t i = 0; i < count; ++i) {
|
||||
valid_data->Add(false);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
auto result = CreateEmptyScalarDataArray(count, field_meta);
|
||||
if (field_meta.default_value().has_value()) {
|
||||
auto res = result->mutable_valid_data()->mutable_data();
|
||||
|
||||
@ -434,6 +434,23 @@ CreateEmptyVectorDataArray(int64_t count, const FieldMeta& field_meta) {
|
||||
return data_array;
|
||||
}
|
||||
|
||||
std::unique_ptr<DataArray>
|
||||
CreateEmptyVectorDataArray(int64_t count,
|
||||
int64_t valid_count,
|
||||
const void* valid_data,
|
||||
const FieldMeta& field_meta) {
|
||||
int64_t data_count = (field_meta.is_nullable() && valid_data != nullptr)
|
||||
? valid_count
|
||||
: count;
|
||||
auto data_array = CreateEmptyVectorDataArray(data_count, field_meta);
|
||||
if (field_meta.is_nullable() && valid_data != nullptr) {
|
||||
auto obj = data_array->mutable_valid_data();
|
||||
auto valid_data_bool = reinterpret_cast<const bool*>(valid_data);
|
||||
obj->Add(valid_data_bool, valid_data_bool + count);
|
||||
}
|
||||
return data_array;
|
||||
}
|
||||
|
||||
std::unique_ptr<DataArray>
|
||||
CreateScalarDataArrayFrom(const void* data_raw,
|
||||
const void* valid_data,
|
||||
@ -444,7 +461,7 @@ CreateScalarDataArrayFrom(const void* data_raw,
|
||||
data_array->set_field_id(field_meta.get_id().get());
|
||||
data_array->set_type(static_cast<milvus::proto::schema::DataType>(
|
||||
field_meta.get_data_type()));
|
||||
if (field_meta.is_nullable()) {
|
||||
if (field_meta.is_nullable() && valid_data != nullptr) {
|
||||
auto valid_data_ = reinterpret_cast<const bool*>(valid_data);
|
||||
auto obj = data_array->mutable_valid_data();
|
||||
obj->Add(valid_data_, valid_data_ + count);
|
||||
@ -659,6 +676,22 @@ CreateVectorDataArrayFrom(const void* data_raw,
|
||||
return data_array;
|
||||
}
|
||||
|
||||
std::unique_ptr<DataArray>
|
||||
CreateVectorDataArrayFrom(const void* data_raw,
|
||||
const void* valid_data,
|
||||
int64_t count,
|
||||
int64_t valid_count,
|
||||
const FieldMeta& field_meta) {
|
||||
auto data_array =
|
||||
CreateVectorDataArrayFrom(data_raw, valid_count, field_meta);
|
||||
if (field_meta.is_nullable() && valid_data != nullptr) {
|
||||
auto obj = data_array->mutable_valid_data();
|
||||
auto valid_data_bool = reinterpret_cast<const bool*>(valid_data);
|
||||
obj->Add(valid_data_bool, valid_data_bool + count);
|
||||
}
|
||||
return data_array;
|
||||
}
|
||||
|
||||
std::unique_ptr<DataArray>
|
||||
CreateDataArrayFrom(const void* data_raw,
|
||||
const void* valid_data,
|
||||
@ -691,6 +724,21 @@ MergeDataArray(std::vector<MergeBase>& merge_bases,
|
||||
AssertInfo(data_type == DataType(src_field_data->type()),
|
||||
"merge field data type not consistent");
|
||||
if (field_meta.is_vector()) {
|
||||
bool is_valid = true;
|
||||
if (nullable) {
|
||||
auto data = src_field_data->valid_data().data();
|
||||
auto obj = data_array->mutable_valid_data();
|
||||
is_valid = data[src_offset];
|
||||
*(obj->Add()) = is_valid;
|
||||
}
|
||||
|
||||
if (!is_valid) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int64_t physical_offset =
|
||||
merge_base.getValidDataOffset(field_meta.get_id());
|
||||
|
||||
auto vector_array = data_array->mutable_vectors();
|
||||
auto dim = 0;
|
||||
if (!IsSparseFloatVectorDataType(data_type)) {
|
||||
@ -700,17 +748,19 @@ MergeDataArray(std::vector<MergeBase>& merge_bases,
|
||||
if (field_meta.get_data_type() == DataType::VECTOR_FLOAT) {
|
||||
auto data = VEC_FIELD_DATA(src_field_data, float).data();
|
||||
auto obj = vector_array->mutable_float_vector();
|
||||
obj->mutable_data()->Add(data + src_offset * dim,
|
||||
data + (src_offset + 1) * dim);
|
||||
obj->mutable_data()->Add(data + physical_offset * dim,
|
||||
data + (physical_offset + 1) * dim);
|
||||
} else if (field_meta.get_data_type() == DataType::VECTOR_FLOAT16) {
|
||||
auto data = VEC_FIELD_DATA(src_field_data, float16);
|
||||
auto obj = vector_array->mutable_float16_vector();
|
||||
obj->assign(data, dim * sizeof(float16));
|
||||
obj->assign(data + physical_offset * dim * sizeof(float16),
|
||||
dim * sizeof(float16));
|
||||
} else if (field_meta.get_data_type() ==
|
||||
DataType::VECTOR_BFLOAT16) {
|
||||
auto data = VEC_FIELD_DATA(src_field_data, bfloat16);
|
||||
auto obj = vector_array->mutable_bfloat16_vector();
|
||||
obj->assign(data, dim * sizeof(bfloat16));
|
||||
obj->assign(data + physical_offset * dim * sizeof(bfloat16),
|
||||
dim * sizeof(bfloat16));
|
||||
} else if (field_meta.get_data_type() == DataType::VECTOR_BINARY) {
|
||||
AssertInfo(
|
||||
dim % 8 == 0,
|
||||
@ -718,26 +768,28 @@ MergeDataArray(std::vector<MergeBase>& merge_bases,
|
||||
auto num_bytes = dim / 8;
|
||||
auto data = VEC_FIELD_DATA(src_field_data, binary);
|
||||
auto obj = vector_array->mutable_binary_vector();
|
||||
obj->assign(data + src_offset * num_bytes, num_bytes);
|
||||
obj->assign(data + physical_offset * num_bytes, num_bytes);
|
||||
} else if (field_meta.get_data_type() ==
|
||||
DataType::VECTOR_SPARSE_U32_F32) {
|
||||
auto src = src_field_data->vectors().sparse_float_vector();
|
||||
auto& src_vec = src_field_data->vectors().sparse_float_vector();
|
||||
auto dst = vector_array->mutable_sparse_float_vector();
|
||||
if (src.dim() > dst->dim()) {
|
||||
dst->set_dim(src.dim());
|
||||
if (src_vec.dim() > dst->dim()) {
|
||||
dst->set_dim(src_vec.dim());
|
||||
}
|
||||
vector_array->set_dim(dst->dim());
|
||||
*dst->mutable_contents() = src.contents();
|
||||
auto& src_contents = src_vec.contents(physical_offset);
|
||||
*(dst->mutable_contents()->Add()) = src_contents;
|
||||
} else if (field_meta.get_data_type() == DataType::VECTOR_INT8) {
|
||||
auto data = VEC_FIELD_DATA(src_field_data, int8);
|
||||
auto obj = vector_array->mutable_int8_vector();
|
||||
obj->assign(data, dim * sizeof(int8));
|
||||
obj->assign(data + physical_offset * dim * sizeof(int8),
|
||||
dim * sizeof(int8));
|
||||
} else if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) {
|
||||
auto data = src_field_data->vectors().vector_array();
|
||||
auto& data = src_field_data->vectors().vector_array();
|
||||
auto obj = vector_array->mutable_vector_array();
|
||||
obj->set_element_type(
|
||||
proto::schema::DataType(field_meta.get_element_type()));
|
||||
obj->CopyFrom(data);
|
||||
*(obj->mutable_data()->Add()) = data.data(physical_offset);
|
||||
} else {
|
||||
ThrowInfo(DataTypeInvalid,
|
||||
fmt::format("unsupported datatype {}", data_type));
|
||||
|
||||
@ -55,6 +55,12 @@ CreateEmptyScalarDataArray(int64_t count, const FieldMeta& field_meta);
|
||||
std::unique_ptr<DataArray>
|
||||
CreateEmptyVectorDataArray(int64_t count, const FieldMeta& field_meta);
|
||||
|
||||
std::unique_ptr<DataArray>
|
||||
CreateEmptyVectorDataArray(int64_t count,
|
||||
int64_t valid_count,
|
||||
const void* valid_data,
|
||||
const FieldMeta& field_meta);
|
||||
|
||||
std::unique_ptr<DataArray>
|
||||
CreateScalarDataArrayFrom(const void* data_raw,
|
||||
const void* valid_data,
|
||||
@ -66,6 +72,13 @@ CreateVectorDataArrayFrom(const void* data_raw,
|
||||
int64_t count,
|
||||
const FieldMeta& field_meta);
|
||||
|
||||
std::unique_ptr<DataArray>
|
||||
CreateVectorDataArrayFrom(const void* data_raw,
|
||||
const void* valid_data,
|
||||
int64_t count,
|
||||
int64_t valid_count,
|
||||
const FieldMeta& field_meta);
|
||||
|
||||
std::unique_ptr<DataArray>
|
||||
CreateDataArrayFrom(const void* data_raw,
|
||||
const void* valid_data,
|
||||
@ -77,6 +90,7 @@ struct MergeBase {
|
||||
private:
|
||||
std::map<FieldId, std::unique_ptr<milvus::DataArray>>* output_fields_data_;
|
||||
size_t offset_;
|
||||
std::map<FieldId, size_t> valid_data_offsets_;
|
||||
|
||||
public:
|
||||
MergeBase() {
|
||||
@ -93,6 +107,20 @@ struct MergeBase {
|
||||
return offset_;
|
||||
}
|
||||
|
||||
void
|
||||
setValidDataOffset(FieldId fieldId, size_t valid_offset) {
|
||||
valid_data_offsets_[fieldId] = valid_offset;
|
||||
}
|
||||
|
||||
size_t
|
||||
getValidDataOffset(FieldId fieldId) const {
|
||||
auto it = valid_data_offsets_.find(fieldId);
|
||||
if (it != valid_data_offsets_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
return offset_;
|
||||
}
|
||||
|
||||
milvus::DataArray*
|
||||
get_field_data(FieldId fieldId) const {
|
||||
return (*output_fields_data_)[fieldId].get();
|
||||
|
||||
@ -94,3 +94,96 @@ TEST(Util_Segcore, GetDeleteBitmap) {
|
||||
delete_record.Query(res_view, insert_barrier, query_timestamp);
|
||||
ASSERT_EQ(res_view.count(), 0);
|
||||
}
|
||||
|
||||
// Checks that CreateVectorDataArrayFrom, given a nullable float-vector field,
// attaches one validity flag per logical row while storing vector data only
// for the valid rows.
TEST(Util_Segcore, CreateVectorDataArrayFromNullableVectors) {
    using namespace milvus;
    using namespace milvus::segcore;

    const int64_t dim = 16;
    const int64_t total_count = 10;
    const int64_t valid_count = 5;

    auto schema = std::make_shared<Schema>();
    auto vec = schema->AddDebugField(
        "embeddings", DataType::VECTOR_FLOAT, dim, knowhere::metric::L2, true);
    auto& field_meta = (*schema)[vec];

    // Payload holds only the valid rows: element k has value k.
    std::vector<float> payload(valid_count * dim);
    size_t next = 0;
    for (auto& value : payload) {
        value = static_cast<float>(next++);
    }

    // Even rows are valid, odd rows are null.
    auto row_flags = std::make_unique<bool[]>(total_count);
    for (int64_t row = 0; row < total_count; ++row) {
        row_flags[row] = (row % 2 == 0);
    }

    auto result = CreateVectorDataArrayFrom(
        payload.data(), row_flags.get(), total_count, valid_count, field_meta);

    ASSERT_TRUE(result->valid_data().size() > 0);
    ASSERT_EQ(result->valid_data().size(), total_count);
    ASSERT_EQ(result->vectors().float_vector().data_size(), valid_count * dim);
}
|
||||
|
||||
// Merges the five valid rows (logical rows 0, 2, 4, 6, 8) of a nullable
// float vector DataArray. Each MergeBase carries the logical row as its
// offset and the matching valid-data offset (0..4) for the vector field.
// The merged result must report five valid rows and 5 * dim floats.
TEST(Util_Segcore, MergeDataArrayWithNullableVectors) {
    using namespace milvus;
    using namespace milvus::segcore;

    int64_t dim = 16;
    int64_t total_count = 10;
    int64_t valid_count = 5;

    auto schema = std::make_shared<Schema>();
    auto vec = schema->AddDebugField(
        "embeddings", DataType::VECTOR_FLOAT, dim, knowhere::metric::L2, true);
    auto& field_meta = (*schema)[vec];

    // Payload for the valid rows only: element k holds the value k.
    std::vector<float> data(valid_count * dim);
    for (size_t pos = 0; pos != data.size(); ++pos) {
        data[pos] = static_cast<float>(pos);
    }

    // Validity flags for every logical row; even rows are valid.
    auto valid_flags = std::make_unique<bool[]>(total_count);
    for (int64_t row = 0; row < total_count; ++row) {
        valid_flags[row] = (row % 2 == 0);
    }

    std::map<FieldId, std::unique_ptr<milvus::DataArray>> output_fields_data;
    output_fields_data[vec] = CreateVectorDataArrayFrom(
        data.data(), valid_flags.get(), total_count, valid_count, field_meta);

    // One merge base per valid row: logical offset 2 * n, valid-data offset n.
    std::vector<MergeBase> merge_bases;
    const auto picked_rows = static_cast<size_t>(valid_count);
    for (size_t nth = 0; nth < picked_rows; ++nth) {
        merge_bases.emplace_back(&output_fields_data, 2 * nth);
        merge_bases.back().setValidDataOffset(vec, nth);
    }

    auto merged_result = MergeDataArray(merge_bases, field_meta);

    ASSERT_GT(merged_result->valid_data().size(), 0);
    ASSERT_EQ(merged_result->valid_data().size(), 5);
    ASSERT_EQ(merged_result->vectors().float_vector().data_size(), 5 * dim);

    for (int merged_row = 0; merged_row < 5; ++merged_row) {
        ASSERT_TRUE(merged_result->valid_data(merged_row));
    }
}
|
||||
|
||||
@ -541,6 +541,27 @@ ReduceHelper::GetSearchResultDataSlice(const int slice_index,
|
||||
|
||||
// set result offset to fill output fields data
|
||||
result_pairs[loc] = {&search_result->output_fields_data_, ki};
|
||||
|
||||
for (auto field_id : plan_->target_entries_) {
|
||||
auto& field_meta = plan_->schema_->operator[](field_id);
|
||||
if (field_meta.is_vector() && field_meta.is_nullable()) {
|
||||
auto it =
|
||||
search_result->output_fields_data_.find(field_id);
|
||||
if (it != search_result->output_fields_data_.end()) {
|
||||
auto& field_data = it->second;
|
||||
if (field_data->valid_data_size() > 0) {
|
||||
int64_t valid_idx = 0;
|
||||
for (int64_t i = 0; i < ki; ++i) {
|
||||
if (field_data->valid_data(i)) {
|
||||
valid_idx++;
|
||||
}
|
||||
}
|
||||
result_pairs[loc].setValidDataOffset(field_id,
|
||||
valid_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -18,6 +18,32 @@
|
||||
|
||||
namespace milvus::segcore {
|
||||
|
||||
void
|
||||
StreamReducerHelper::SetNullableVectorValidDataOffsets(
|
||||
const std::map<FieldId, std::unique_ptr<milvus::DataArray>>&
|
||||
output_fields_data,
|
||||
int64_t ki,
|
||||
MergeBase& merge_base) {
|
||||
for (auto field_id : plan_->target_entries_) {
|
||||
auto& field_meta = plan_->schema_->operator[](field_id);
|
||||
if (field_meta.is_vector() && field_meta.is_nullable()) {
|
||||
auto it = output_fields_data.find(field_id);
|
||||
if (it != output_fields_data.end()) {
|
||||
auto& field_data = it->second;
|
||||
if (field_data->valid_data_size() > 0) {
|
||||
int64_t physical_offset = 0;
|
||||
for (int64_t j = 0; j < ki; ++j) {
|
||||
if (field_data->valid_data(j)) {
|
||||
physical_offset++;
|
||||
}
|
||||
}
|
||||
merge_base.setValidDataOffset(field_id, physical_offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
StreamReducerHelper::FillEntryData() {
|
||||
for (auto search_result : search_results_to_merge_) {
|
||||
@ -98,6 +124,10 @@ StreamReducerHelper::AssembleMergedResult() {
|
||||
}
|
||||
merge_output_data_bases[nq_base_offset + loc] = {
|
||||
&search_result->output_fields_data_, ki};
|
||||
SetNullableVectorValidDataOffsets(
|
||||
search_result->output_fields_data_,
|
||||
ki,
|
||||
merge_output_data_bases[nq_base_offset + loc]);
|
||||
new_result_offsets[nq_base_offset + loc] = loc;
|
||||
real_topKs[qi]++;
|
||||
}
|
||||
@ -127,6 +157,10 @@ StreamReducerHelper::AssembleMergedResult() {
|
||||
}
|
||||
merge_output_data_bases[nq_base_offset + loc] = {
|
||||
&merged_search_result->output_fields_data_, ki};
|
||||
SetNullableVectorValidDataOffsets(
|
||||
merged_search_result->output_fields_data_,
|
||||
ki,
|
||||
merge_output_data_bases[nq_base_offset + loc]);
|
||||
new_result_offsets[nq_base_offset + loc] = loc;
|
||||
real_topKs[qi]++;
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
#include "query/PlanImpl.h"
|
||||
#include "common/QueryResult.h"
|
||||
#include "segcore/ReduceStructure.h"
|
||||
#include "segcore/Utils.h"
|
||||
#include "common/EasyAssert.h"
|
||||
|
||||
namespace milvus::segcore {
|
||||
@ -207,6 +208,13 @@ class StreamReducerHelper {
|
||||
void
|
||||
CleanReduceStatus();
|
||||
|
||||
void
|
||||
SetNullableVectorValidDataOffsets(
|
||||
const std::map<FieldId, std::unique_ptr<milvus::DataArray>>&
|
||||
output_fields_data,
|
||||
int64_t ki,
|
||||
MergeBase& merge_base);
|
||||
|
||||
std::unique_ptr<MergedSearchResult> merged_search_result;
|
||||
milvus::query::Plan* plan_;
|
||||
std::vector<int64_t> slice_nqs_;
|
||||
|
||||
@ -107,6 +107,16 @@ DefaultValueChunkTranslator::estimated_byte_size_of_cell(
|
||||
case milvus::DataType::ARRAY:
|
||||
value_size = sizeof(Array);
|
||||
break;
|
||||
case milvus::DataType::VECTOR_FLOAT:
|
||||
case milvus::DataType::VECTOR_BINARY:
|
||||
case milvus::DataType::VECTOR_FLOAT16:
|
||||
case milvus::DataType::VECTOR_BFLOAT16:
|
||||
case milvus::DataType::VECTOR_INT8:
|
||||
case milvus::DataType::VECTOR_SPARSE_U32_F32:
|
||||
AssertInfo(field_meta_.is_nullable(),
|
||||
"only nullable vector fields can be dynamically added");
|
||||
value_size = 0;
|
||||
break;
|
||||
default:
|
||||
ThrowInfo(DataTypeInvalid,
|
||||
"unsupported default value data type {}",
|
||||
@ -128,8 +138,15 @@ DefaultValueChunkTranslator::get_cells(
|
||||
AssertInfo(cids.size() == 1 && cids[0] == 0,
|
||||
"DefaultValueChunkTranslator only supports one cell");
|
||||
auto num_rows = meta_.num_rows_until_chunk_[1];
|
||||
auto builder =
|
||||
milvus::storage::CreateArrowBuilder(field_meta_.get_data_type());
|
||||
auto data_type = field_meta_.get_data_type();
|
||||
std::shared_ptr<arrow::ArrayBuilder> builder;
|
||||
if (IsVectorDataType(data_type)) {
|
||||
AssertInfo(field_meta_.is_nullable(),
|
||||
"only nullable vector fields can be dynamically added");
|
||||
builder = std::make_shared<arrow::BinaryBuilder>();
|
||||
} else {
|
||||
builder = milvus::storage::CreateArrowBuilder(data_type);
|
||||
}
|
||||
arrow::Status ast;
|
||||
if (field_meta_.default_value().has_value()) {
|
||||
ast = builder->Reserve(num_rows);
|
||||
|
||||
@ -143,20 +143,54 @@ InterimSealedIndexTranslator::get_cells(
|
||||
}
|
||||
|
||||
auto num_chunk = vec_data_->num_chunks();
|
||||
const auto& offset_mapping = vec_data_->GetOffsetMapping();
|
||||
bool nullable = offset_mapping.IsEnabled();
|
||||
const auto& valid_count_per_chunk =
|
||||
nullable ? vec_data_->GetValidCountPerChunk() : std::vector<int64_t>{};
|
||||
|
||||
int64_t total_valid_count =
|
||||
nullable ? offset_mapping.GetValidCount() : vec_data_->NumRows();
|
||||
|
||||
if (total_valid_count == 0) {
|
||||
if (nullable) {
|
||||
const auto& valid_data = vec_data_->GetValidData();
|
||||
vec_index->BuildValidData(valid_data.data(), valid_data.size());
|
||||
}
|
||||
std::vector<std::pair<cid_t, std::unique_ptr<milvus::index::IndexBase>>>
|
||||
result;
|
||||
result.emplace_back(std::make_pair(0, std::move(vec_index)));
|
||||
return result;
|
||||
}
|
||||
|
||||
bool first_build = true;
|
||||
for (int i = 0; i < num_chunk; ++i) {
|
||||
auto pw = vec_data_->GetChunk(nullptr, i);
|
||||
auto chunk = pw.get();
|
||||
auto dataset = knowhere::GenDataSet(
|
||||
vec_data_->chunk_row_nums(i), dim_, chunk->Data());
|
||||
|
||||
int64_t actual_row_count =
|
||||
nullable ? valid_count_per_chunk[i] : vec_data_->chunk_row_nums(i);
|
||||
|
||||
if (actual_row_count == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto dataset =
|
||||
knowhere::GenDataSet(actual_row_count, dim_, chunk->Data());
|
||||
dataset->SetIsOwner(false);
|
||||
dataset->SetIsSparse(is_sparse_);
|
||||
|
||||
if (i == 0) {
|
||||
if (first_build) {
|
||||
vec_index->BuildWithDataset(dataset, build_config_);
|
||||
first_build = false;
|
||||
} else {
|
||||
vec_index->AddWithDataset(dataset, build_config_);
|
||||
}
|
||||
}
|
||||
|
||||
if (nullable) {
|
||||
const auto& valid_data = vec_data_->GetValidData();
|
||||
vec_index->BuildValidData(valid_data.data(), valid_data.size());
|
||||
}
|
||||
std::vector<std::pair<cid_t, std::unique_ptr<milvus::index::IndexBase>>>
|
||||
result;
|
||||
result.emplace_back(std::make_pair(0, std::move(vec_index)));
|
||||
|
||||
@ -763,7 +763,80 @@ TEST(storage, InsertDataFloatVector) {
|
||||
ASSERT_EQ(data, new_data);
|
||||
}
|
||||
|
||||
TEST(storage, InsertDataSparseFloat) {
|
||||
// Round-trips a nullable float vector column through InsertData
// serialization at 0%, 20% and 100% null ratios. The first valid_count
// logical rows are marked valid, and the payload passed to FillFieldData
// holds valid_count vectors.
TEST(storage, InsertDataFloatVectorNullable) {
    int DIM = 4;
    int num_rows = 100;

    for (int null_percent : {0, 20, 100}) {
        int valid_count = num_rows * (100 - null_percent) / 100;
        bool is_nullable = true;

        // Dense payload: one vector per valid row.
        std::vector<float> data(valid_count * DIM);
        for (int pos = 0; pos < valid_count * DIM; ++pos) {
            data[pos] = static_cast<float>(pos) * 0.5f;
        }

        // Validity bitmap: one bit per logical row, lowest bit first.
        std::vector<uint8_t> valid_data((num_rows + 7) / 8, 0);
        for (int row = 0; row < valid_count; ++row) {
            valid_data[row / 8] |= (1 << (row % 8));
        }

        FieldDataPtr field_data = milvus::storage::CreateFieldData(
            storage::DataType::VECTOR_FLOAT, DataType::NONE, true, DIM);
        auto typed_field =
            std::dynamic_pointer_cast<milvus::FieldData<milvus::FloatVector>>(
                field_data);
        typed_field->FillFieldData(
            data.data(), valid_data.data(), num_rows, 0);

        ASSERT_EQ(field_data->get_num_rows(), num_rows);
        ASSERT_EQ(field_data->get_valid_rows(), valid_count);
        ASSERT_EQ(field_data->get_null_count(), num_rows - valid_count);
        ASSERT_EQ(field_data->IsNullable(), is_nullable);

        auto payload_reader =
            std::make_shared<milvus::storage::PayloadReader>(field_data);
        storage::InsertData insert_data(payload_reader);
        storage::FieldDataMeta field_data_meta{100, 101, 102, 103};
        insert_data.SetFieldDataMeta(field_data_meta);
        insert_data.SetTimestamps(0, 100);

        auto serialized_bytes =
            insert_data.Serialize(storage::StorageType::Remote);
        // Non-owning view: the deleter is a no-op because serialized_bytes
        // keeps the buffer alive.
        std::shared_ptr<uint8_t[]> serialized_view(serialized_bytes.data(),
                                                   [&](uint8_t*) {});
        auto restored = storage::DeserializeFileData(serialized_view,
                                                     serialized_bytes.size());
        ASSERT_EQ(restored->GetCodecType(), storage::InsertDataType);
        ASSERT_EQ(restored->GetTimeRage(),
                  std::make_pair(Timestamp(0), Timestamp(100)));

        auto new_payload = restored->GetFieldData();

        ASSERT_EQ(new_payload->get_data_type(),
                  storage::DataType::VECTOR_FLOAT);
        ASSERT_EQ(new_payload->get_num_rows(), num_rows);
        ASSERT_EQ(new_payload->get_valid_rows(), valid_count);
        ASSERT_EQ(new_payload->get_null_count(), num_rows - valid_count);
        ASSERT_EQ(new_payload->IsNullable(), is_nullable);

        // Walk logical rows; valid_seen indexes the dense source payload.
        int valid_seen = 0;
        for (int row = 0; row < num_rows; ++row) {
            if (!new_payload->is_valid(row)) {
                continue;
            }
            // RawValue takes logical offset, internally converts to physical
            auto restored_vec =
                static_cast<const float*>(new_payload->RawValue(row));
            for (int d = 0; d < DIM; ++d) {
                ASSERT_FLOAT_EQ(restored_vec[d], data[valid_seen * DIM + d]);
            }
            valid_seen++;
        }
    }
}
|
||||
|
||||
TEST(storage, InsertDataSparseFloatVector) {
|
||||
auto n_rows = 100;
|
||||
auto vecs = milvus::segcore::GenerateRandomSparseFloatVector(
|
||||
n_rows, kTestSparseDim, kTestSparseVectorDensity);
|
||||
@ -810,6 +883,75 @@ TEST(storage, InsertDataSparseFloat) {
|
||||
}
|
||||
}
|
||||
|
||||
// Round-trips a nullable sparse float vector column through InsertData
// serialization at 0%, 20% and 100% null ratios. The first valid_count
// logical rows are marked valid and valid_count sparse rows are generated.
TEST(storage, InsertDataSparseFloatVectorNullable) {
    int num_rows = 100;

    for (int null_percent : {0, 20, 100}) {
        int valid_count = num_rows * (100 - null_percent) / 100;
        bool is_nullable = true;
        auto vecs = milvus::segcore::GenerateRandomSparseFloatVector(
            valid_count, kTestSparseDim, kTestSparseVectorDensity);

        // Validity bitmap: one bit per logical row, lowest bit first.
        std::vector<uint8_t> valid_data((num_rows + 7) / 8, 0);
        for (int row = 0; row < valid_count; ++row) {
            valid_data[row / 8] |= (1 << (row % 8));
        }

        FieldDataPtr field_data =
            milvus::storage::CreateFieldData(DataType::VECTOR_SPARSE_U32_F32,
                                             DataType::NONE,
                                             true,
                                             kTestSparseDim,
                                             num_rows);

        auto typed_field = std::dynamic_pointer_cast<
            milvus::FieldData<milvus::SparseFloatVector>>(field_data);
        typed_field->FillFieldData(
            vecs.get(), valid_data.data(), num_rows, 0);

        ASSERT_EQ(field_data->get_num_rows(), num_rows);
        ASSERT_EQ(field_data->get_valid_rows(), valid_count);
        ASSERT_EQ(field_data->IsNullable(), is_nullable);

        auto payload_reader =
            std::make_shared<milvus::storage::PayloadReader>(field_data);
        storage::InsertData insert_data(payload_reader);
        storage::FieldDataMeta field_data_meta{100, 101, 102, 103};
        insert_data.SetFieldDataMeta(field_data_meta);
        insert_data.SetTimestamps(0, 100);

        auto serialized_bytes =
            insert_data.Serialize(storage::StorageType::Remote);
        // Non-owning view: the deleter is a no-op because serialized_bytes
        // keeps the buffer alive.
        std::shared_ptr<uint8_t[]> serialized_view(serialized_bytes.data(),
                                                   [&](uint8_t*) {});
        auto restored = storage::DeserializeFileData(serialized_view,
                                                     serialized_bytes.size());
        ASSERT_EQ(restored->GetCodecType(), storage::InsertDataType);

        auto new_payload = restored->GetFieldData();
        ASSERT_TRUE(new_payload->get_data_type() ==
                    storage::DataType::VECTOR_SPARSE_U32_F32);
        ASSERT_EQ(new_payload->get_num_rows(), num_rows);
        ASSERT_EQ(new_payload->IsNullable(), is_nullable);

        // Walk logical rows; valid_seen indexes the generated sparse rows.
        int valid_seen = 0;
        for (int row = 0; row < num_rows; ++row) {
            if (!new_payload->is_valid(row)) {
                continue;
            }
            auto& original = vecs[valid_seen];
            auto restored_row = static_cast<const knowhere::sparse::SparseRow<
                milvus::SparseValueType>*>(new_payload->RawValue(row));
            ASSERT_EQ(original.size(), restored_row->size());
            for (size_t k = 0; k < original.size(); ++k) {
                ASSERT_EQ(original[k].id, (*restored_row)[k].id);
                ASSERT_EQ(original[k].val, (*restored_row)[k].val);
            }
            valid_seen++;
        }
    }
}
|
||||
|
||||
TEST(storage, InsertDataBinaryVector) {
|
||||
std::vector<uint8_t> data = {1, 2, 3, 4, 5, 6, 7, 8};
|
||||
int DIM = 16;
|
||||
@ -841,6 +983,72 @@ TEST(storage, InsertDataBinaryVector) {
|
||||
ASSERT_EQ(data, new_data);
|
||||
}
|
||||
|
||||
// Round-trips a nullable binary vector column (DIM bits per row) through
// InsertData serialization at 0%, 20% and 100% null ratios. The first
// valid_count logical rows are marked valid, and the payload passed to
// FillFieldData holds valid_count vectors.
TEST(storage, InsertDataBinaryVectorNullable) {
    int DIM = 128;
    int num_rows = 100;
    // Bytes occupied by one binary vector.
    const int bytes_per_row = DIM / 8;

    for (int null_percent : {0, 20, 100}) {
        int valid_count = num_rows * (100 - null_percent) / 100;
        bool is_nullable = true;

        // Dense payload: one packed vector per valid row.
        std::vector<uint8_t> data(valid_count * bytes_per_row);
        for (size_t pos = 0; pos != data.size(); ++pos) {
            data[pos] = static_cast<uint8_t>(pos % 256);
        }

        // Validity bitmap: one bit per logical row, lowest bit first.
        std::vector<uint8_t> valid_data((num_rows + 7) / 8, 0);
        for (int row = 0; row < valid_count; ++row) {
            valid_data[row / 8] |= (1 << (row % 8));
        }

        FieldDataPtr field_data = milvus::storage::CreateFieldData(
            storage::DataType::VECTOR_BINARY, DataType::NONE, true, DIM);
        auto typed_field =
            std::dynamic_pointer_cast<milvus::FieldData<milvus::BinaryVector>>(
                field_data);
        typed_field->FillFieldData(
            data.data(), valid_data.data(), num_rows, 0);

        ASSERT_EQ(field_data->get_num_rows(), num_rows);
        ASSERT_EQ(field_data->get_valid_rows(), valid_count);
        ASSERT_EQ(field_data->IsNullable(), is_nullable);

        auto payload_reader =
            std::make_shared<milvus::storage::PayloadReader>(field_data);
        storage::InsertData insert_data(payload_reader);
        storage::FieldDataMeta field_data_meta{100, 101, 102, 103};
        insert_data.SetFieldDataMeta(field_data_meta);
        insert_data.SetTimestamps(0, 100);

        auto serialized_bytes =
            insert_data.Serialize(storage::StorageType::Remote);
        // Non-owning view: the deleter is a no-op because serialized_bytes
        // keeps the buffer alive.
        std::shared_ptr<uint8_t[]> serialized_view(serialized_bytes.data(),
                                                   [&](uint8_t*) {});
        auto restored = storage::DeserializeFileData(serialized_view,
                                                     serialized_bytes.size());
        ASSERT_EQ(restored->GetCodecType(), storage::InsertDataType);

        auto new_payload = restored->GetFieldData();
        ASSERT_EQ(new_payload->get_data_type(),
                  storage::DataType::VECTOR_BINARY);
        ASSERT_EQ(new_payload->get_num_rows(), num_rows);
        ASSERT_EQ(new_payload->IsNullable(), is_nullable);

        // Walk logical rows; valid_seen indexes the dense source payload.
        int valid_seen = 0;
        for (int row = 0; row < num_rows; ++row) {
            if (!new_payload->is_valid(row)) {
                continue;
            }
            auto restored_vec =
                static_cast<const uint8_t*>(new_payload->RawValue(row));
            for (int b = 0; b < bytes_per_row; ++b) {
                ASSERT_EQ(restored_vec[b],
                          data[valid_seen * bytes_per_row + b]);
            }
            valid_seen++;
        }
    }
}
|
||||
|
||||
TEST(storage, InsertDataFloat16Vector) {
|
||||
std::vector<float16> data = {1, 2, 3, 4, 5, 6, 7, 8};
|
||||
int DIM = 2;
|
||||
@ -874,6 +1082,72 @@ TEST(storage, InsertDataFloat16Vector) {
|
||||
ASSERT_EQ(data, new_data);
|
||||
}
|
||||
|
||||
// Round-trips a nullable float16 vector column through InsertData
// serialization at 0%, 20% and 100% null ratios. The first valid_count
// logical rows are marked valid, and the payload passed to FillFieldData
// holds valid_count vectors.
TEST(storage, InsertDataFloat16VectorNullable) {
    int DIM = 4;
    int num_rows = 100;

    for (int null_percent : {0, 20, 100}) {
        int valid_count = num_rows * (100 - null_percent) / 100;
        bool is_nullable = true;

        // Dense payload: one vector per valid row.
        std::vector<float16> data(valid_count * DIM);
        for (int pos = 0; pos < valid_count * DIM; ++pos) {
            data[pos] = static_cast<float16>(pos * 0.5f);
        }

        // Validity bitmap: one bit per logical row, lowest bit first.
        std::vector<uint8_t> valid_data((num_rows + 7) / 8, 0);
        for (int row = 0; row < valid_count; ++row) {
            valid_data[row / 8] |= (1 << (row % 8));
        }

        FieldDataPtr field_data = milvus::storage::CreateFieldData(
            storage::DataType::VECTOR_FLOAT16, DataType::NONE, true, DIM);
        auto typed_field =
            std::dynamic_pointer_cast<milvus::FieldData<milvus::Float16Vector>>(
                field_data);
        typed_field->FillFieldData(
            data.data(), valid_data.data(), num_rows, 0);

        ASSERT_EQ(field_data->get_num_rows(), num_rows);
        ASSERT_EQ(field_data->get_valid_rows(), valid_count);
        ASSERT_EQ(field_data->IsNullable(), is_nullable);

        auto payload_reader =
            std::make_shared<milvus::storage::PayloadReader>(field_data);
        storage::InsertData insert_data(payload_reader);
        storage::FieldDataMeta field_data_meta{100, 101, 102, 103};
        insert_data.SetFieldDataMeta(field_data_meta);
        insert_data.SetTimestamps(0, 100);

        auto serialized_bytes =
            insert_data.Serialize(storage::StorageType::Remote);
        // Non-owning view: the deleter is a no-op because serialized_bytes
        // keeps the buffer alive.
        std::shared_ptr<uint8_t[]> serialized_view(serialized_bytes.data(),
                                                   [&](uint8_t*) {});
        auto restored = storage::DeserializeFileData(serialized_view,
                                                     serialized_bytes.size());
        ASSERT_EQ(restored->GetCodecType(), storage::InsertDataType);

        auto new_payload = restored->GetFieldData();
        ASSERT_EQ(new_payload->get_data_type(),
                  storage::DataType::VECTOR_FLOAT16);
        ASSERT_EQ(new_payload->get_num_rows(), num_rows);
        ASSERT_EQ(new_payload->IsNullable(), is_nullable);

        // Walk logical rows; valid_seen indexes the dense source payload.
        int valid_seen = 0;
        for (int row = 0; row < num_rows; ++row) {
            if (!new_payload->is_valid(row)) {
                continue;
            }
            auto restored_vec =
                static_cast<const float16*>(new_payload->RawValue(row));
            for (int d = 0; d < DIM; ++d) {
                ASSERT_EQ(restored_vec[d], data[valid_seen * DIM + d]);
            }
            valid_seen++;
        }
    }
}
|
||||
|
||||
TEST(storage, IndexData) {
|
||||
std::vector<uint8_t> data = {1, 2, 3, 4, 5, 6, 7, 8};
|
||||
storage::IndexData index_data(data.data(), data.size());
|
||||
|
||||
@ -536,7 +536,7 @@ DiskFileManagerImpl::cache_raw_data_to_disk_internal(const Config& config) {
|
||||
int batch_size = batch_files.size();
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
auto field_data = field_datas[i].get()->GetFieldData();
|
||||
num_rows += uint32_t(field_data->get_num_rows());
|
||||
num_rows += uint32_t(field_data->get_valid_rows());
|
||||
cache_raw_data_to_disk_common<DataType>(
|
||||
field_data,
|
||||
local_chunk_manager,
|
||||
@ -634,7 +634,7 @@ DiskFileManagerImpl::cache_raw_data_to_disk_common(
|
||||
auto sparse_rows =
|
||||
static_cast<const knowhere::sparse::SparseRow<SparseValueType>*>(
|
||||
field_data->Data());
|
||||
for (size_t i = 0; i < field_data->Length(); ++i) {
|
||||
for (size_t i = 0; i < field_data->get_valid_rows(); ++i) {
|
||||
auto row = sparse_rows[i];
|
||||
auto row_byte_size = row.data_byte_size();
|
||||
uint32_t nnz = row.size();
|
||||
@ -689,7 +689,7 @@ DiskFileManagerImpl::cache_raw_data_to_disk_common(
|
||||
} else {
|
||||
dim = field_data->get_dim();
|
||||
auto data_size =
|
||||
field_data->get_num_rows() * milvus::GetVecRowSize<DataType>(dim);
|
||||
field_data->get_valid_rows() * milvus::GetVecRowSize<DataType>(dim);
|
||||
local_chunk_manager->Write(local_data_path,
|
||||
write_offset,
|
||||
const_cast<void*>(field_data->Data()),
|
||||
@ -761,7 +761,7 @@ DiskFileManagerImpl::cache_raw_data_to_disk_storage_v2(const Config& config) {
|
||||
fs_);
|
||||
}
|
||||
for (auto& field_data : field_datas) {
|
||||
num_rows += uint32_t(field_data->get_num_rows());
|
||||
num_rows += uint32_t(field_data->get_valid_rows());
|
||||
cache_raw_data_to_disk_common<T>(field_data,
|
||||
local_chunk_manager,
|
||||
local_data_path,
|
||||
|
||||
@ -523,6 +523,262 @@ TEST_F(DiskAnnFileManagerTest, CacheOptFieldToDiskOnlyOneCategory) {
|
||||
}
|
||||
}
|
||||
|
||||
// Serializes nullable vector field data for each listed vector type at 0%,
// 20% and 100% null ratios and writes it through cm_.
//  - Sparse vectors: the written file is read back and deserialized, and
//    validity plus row contents are compared against the generated rows.
//  - Dense vectors: the file is passed to CacheRawDataToDisk and the local
//    output (row count, dim, raw bytes) is compared against the bytes of
//    the valid rows only.
TEST_F(DiskAnnFileManagerTest, CacheRawDataToDiskNullableVector) {
    const int64_t collection_id = 1;
    const int64_t partition_id = 2;
    const int64_t segment_id = 3;
    const int64_t field_id = 100;
    const int64_t dim = 128;
    const int64_t num_rows = 1000;

    struct VectorTypeInfo {
        DataType data_type;
        std::string type_name;
        size_t element_size;
        bool is_sparse;
    };

    std::vector<VectorTypeInfo> vector_types = {
        {DataType::VECTOR_FLOAT, "FLOAT", sizeof(float), false},
        {DataType::VECTOR_FLOAT16, "FLOAT16", sizeof(knowhere::fp16), false},
        {DataType::VECTOR_BFLOAT16, "BFLOAT16", sizeof(knowhere::bf16), false},
        {DataType::VECTOR_INT8, "INT8", sizeof(int8_t), false},
        {DataType::VECTOR_BINARY, "BINARY", dim / 8, false},
        {DataType::VECTOR_SPARSE_U32_F32, "SPARSE", 0, true}};

    for (const auto& vec_type : vector_types) {
        // Size in bytes of one dense vector of this type; binary vectors
        // pack dim bits, the other types use dim elements. Not used on the
        // sparse path.
        const size_t bytes_per_vector =
            (vec_type.data_type == DataType::VECTOR_BINARY)
                ? (dim / 8)
                : (dim * vec_type.element_size);

        for (int null_percent : {0, 20, 100}) {
            int64_t valid_count = num_rows * (100 - null_percent) / 100;

            // Validity bitmap: the first valid_count logical rows are valid.
            std::vector<uint8_t> valid_data((num_rows + 7) / 8, 0);
            for (int64_t row = 0; row < valid_count; ++row) {
                valid_data[row / 8] |= (1 << (row % 8));
            }

            FieldDataPtr field_data;
            std::vector<uint8_t> vec_data;
            std::unique_ptr<knowhere::sparse::SparseRow<float>[]> sparse_vecs;

            if (vec_type.is_sparse) {
                const int64_t sparse_dim = 1000;
                const float sparse_density = 0.1;
                sparse_vecs = milvus::segcore::GenerateRandomSparseFloatVector(
                    valid_count, sparse_dim, sparse_density);

                field_data =
                    storage::CreateFieldData(DataType::VECTOR_SPARSE_U32_F32,
                                             DataType::NONE,
                                             true,
                                             sparse_dim,
                                             num_rows);
                auto typed_field = std::dynamic_pointer_cast<
                    milvus::FieldData<milvus::SparseFloatVector>>(field_data);
                typed_field->FillFieldData(
                    sparse_vecs.get(), valid_data.data(), num_rows, 0);
            } else {
                // Raw bytes of the valid rows only, filled with a repeating
                // 0..255 pattern.
                vec_data.resize(valid_count * bytes_per_vector);
                for (size_t pos = 0; pos != vec_data.size(); ++pos) {
                    vec_data[pos] = static_cast<uint8_t>(pos % 256);
                }

                field_data = storage::CreateFieldData(
                    vec_type.data_type, DataType::NONE, true, dim);

                switch (vec_type.data_type) {
                    case DataType::VECTOR_FLOAT: {
                        auto impl = std::dynamic_pointer_cast<
                            milvus::FieldData<milvus::FloatVector>>(
                            field_data);
                        impl->FillFieldData(
                            vec_data.data(), valid_data.data(), num_rows, 0);
                        break;
                    }
                    case DataType::VECTOR_FLOAT16: {
                        auto impl = std::dynamic_pointer_cast<
                            milvus::FieldData<milvus::Float16Vector>>(
                            field_data);
                        impl->FillFieldData(
                            vec_data.data(), valid_data.data(), num_rows, 0);
                        break;
                    }
                    case DataType::VECTOR_BFLOAT16: {
                        auto impl = std::dynamic_pointer_cast<
                            milvus::FieldData<milvus::BFloat16Vector>>(
                            field_data);
                        impl->FillFieldData(
                            vec_data.data(), valid_data.data(), num_rows, 0);
                        break;
                    }
                    case DataType::VECTOR_INT8: {
                        auto impl = std::dynamic_pointer_cast<
                            milvus::FieldData<milvus::Int8Vector>>(field_data);
                        impl->FillFieldData(
                            vec_data.data(), valid_data.data(), num_rows, 0);
                        break;
                    }
                    case DataType::VECTOR_BINARY: {
                        auto impl = std::dynamic_pointer_cast<
                            milvus::FieldData<milvus::BinaryVector>>(
                            field_data);
                        impl->FillFieldData(
                            vec_data.data(), valid_data.data(), num_rows, 0);
                        break;
                    }
                    default:
                        break;
                }
            }

            ASSERT_EQ(field_data->get_num_rows(), num_rows);
            ASSERT_EQ(field_data->get_valid_rows(), valid_count);

            auto payload_reader =
                std::make_shared<milvus::storage::PayloadReader>(field_data);
            storage::InsertData insert_data(payload_reader);
            FieldDataMeta field_data_meta = {
                collection_id, partition_id, segment_id, field_id};
            insert_data.SetFieldDataMeta(field_data_meta);
            insert_data.SetTimestamps(0, 100);

            auto serialized_data =
                insert_data.Serialize(storage::StorageType::Remote);

            std::string insert_file_path = "/tmp/diskann/nullable_" +
                                           vec_type.type_name + "_" +
                                           std::to_string(null_percent);
            boost::filesystem::remove_all(insert_file_path);
            cm_->Write(insert_file_path,
                       serialized_data.data(),
                       serialized_data.size());

            if (vec_type.is_sparse) {
                // Sparse path: read the file back and compare row by row.
                int64_t file_size = cm_->Size(insert_file_path);
                std::vector<uint8_t> file_bytes(file_size);
                cm_->Read(insert_file_path, file_bytes.data(), file_size);

                // Non-owning view: the deleter is a no-op because file_bytes
                // keeps the buffer alive.
                std::shared_ptr<uint8_t[]> file_view(file_bytes.data(),
                                                     [&](uint8_t*) {});
                auto restored = storage::DeserializeFileData(
                    file_view, file_bytes.size());
                ASSERT_EQ(restored->GetCodecType(), storage::InsertDataType);

                auto new_payload = restored->GetFieldData();
                ASSERT_TRUE(new_payload->get_data_type() ==
                            DataType::VECTOR_SPARSE_U32_F32);
                ASSERT_EQ(new_payload->get_num_rows(), num_rows)
                    << "num_rows mismatch for " << vec_type.type_name
                    << " with null_percent=" << null_percent;
                ASSERT_EQ(new_payload->get_valid_rows(), valid_count)
                    << "valid_rows mismatch for " << vec_type.type_name
                    << " with null_percent=" << null_percent;
                ASSERT_TRUE(new_payload->IsNullable());

                for (int i = 0; i < num_rows; ++i) {
                    if (!(i < valid_count)) {
                        ASSERT_FALSE(new_payload->is_valid(i))
                            << "Row " << i
                            << " should be null for null_percent="
                            << null_percent;
                        continue;
                    }

                    ASSERT_TRUE(new_payload->is_valid(i))
                        << "Row " << i
                        << " should be valid for null_percent="
                        << null_percent;

                    auto original = &sparse_vecs[i];
                    auto new_vec =
                        static_cast<const knowhere::sparse::SparseRow<
                            milvus::SparseValueType>*>(
                            new_payload->RawValue(i));
                    ASSERT_EQ(original->size(), new_vec->size())
                        << "Size mismatch at row " << i
                        << " for null_percent=" << null_percent;

                    for (size_t j = 0; j < original->size(); ++j) {
                        ASSERT_EQ((*original)[j].id, (*new_vec)[j].id)
                            << "ID mismatch at row " << i << ", element "
                            << j << " for null_percent=" << null_percent;
                        ASSERT_EQ((*original)[j].val, (*new_vec)[j].val)
                            << "Value mismatch at row " << i << ", element "
                            << j << " for null_percent=" << null_percent;
                    }
                }
            } else {
                // Dense path: cache to local disk and inspect the raw file,
                // read here as <uint32 rows><uint32 dim><vector bytes>.
                IndexMeta index_meta = {segment_id,
                                        field_id,
                                        1000,
                                        1,
                                        "test",
                                        "vec_field",
                                        vec_type.data_type,
                                        dim};
                auto file_manager = std::make_shared<DiskFileManagerImpl>(
                    storage::FileManagerContext(
                        field_data_meta, index_meta, cm_, fs_));

                milvus::Config config;
                config[INSERT_FILES_KEY] =
                    std::vector<std::string>{insert_file_path};

                std::string local_data_path;
                switch (vec_type.data_type) {
                    case DataType::VECTOR_FLOAT:
                        local_data_path =
                            file_manager->CacheRawDataToDisk<float>(config);
                        break;
                    case DataType::VECTOR_INT8:
                        local_data_path =
                            file_manager->CacheRawDataToDisk<int8_t>(config);
                        break;
                    case DataType::VECTOR_FLOAT16:
                        local_data_path =
                            file_manager->CacheRawDataToDisk<knowhere::fp16>(
                                config);
                        break;
                    case DataType::VECTOR_BFLOAT16:
                        local_data_path =
                            file_manager->CacheRawDataToDisk<knowhere::bf16>(
                                config);
                        break;
                    case DataType::VECTOR_BINARY:
                        local_data_path =
                            file_manager->CacheRawDataToDisk<uint8_t>(config);
                        break;
                    default:
                        break;
                }

                ASSERT_FALSE(local_data_path.empty())
                    << "Failed for " << vec_type.type_name
                    << " with null_percent=" << null_percent;

                auto local_chunk_manager =
                    LocalChunkManagerSingleton::GetInstance().GetChunkManager();
                uint32_t read_num_rows = 0;
                uint32_t read_dim = 0;
                local_chunk_manager->Read(
                    local_data_path, 0, &read_num_rows, sizeof(read_num_rows));
                local_chunk_manager->Read(local_data_path,
                                          sizeof(read_num_rows),
                                          &read_dim,
                                          sizeof(read_dim));

                EXPECT_EQ(read_num_rows, valid_count)
                    << "Mismatch for " << vec_type.type_name
                    << " with null_percent=" << null_percent;
                EXPECT_EQ(read_dim, dim);

                auto data_size = read_num_rows * bytes_per_vector;
                std::vector<uint8_t> buffer(data_size);
                local_chunk_manager->Read(
                    local_data_path,
                    sizeof(read_num_rows) + sizeof(read_dim),
                    buffer.data(),
                    data_size);

                EXPECT_EQ(buffer.size(), vec_data.size())
                    << "Data size mismatch for " << vec_type.type_name;
                // Compare only the overlapping prefix so a size mismatch
                // does not cause an out-of-range read.
                const size_t comparable =
                    std::min(buffer.size(), vec_data.size());
                for (size_t i = 0; i < comparable; ++i) {
                    EXPECT_EQ(buffer[i], vec_data[i])
                        << "Data mismatch at byte " << i << " for "
                        << vec_type.type_name
                        << " with null_percent=" << null_percent;
                }

                local_chunk_manager->Remove(local_data_path);
            }

            cm_->Remove(insert_file_path);
        }
    }
}
|
||||
|
||||
TEST_F(DiskAnnFileManagerTest, FileCleanup) {
|
||||
std::string local_index_file_path;
|
||||
std::string local_text_index_file_path;
|
||||
|
||||
@ -336,9 +336,13 @@ BaseEventData::Serialize() {
|
||||
auto row = static_cast<
|
||||
const knowhere::sparse::SparseRow<SparseValueType>*>(
|
||||
field_data->RawValue(offset));
|
||||
payload_writer->add_one_binary_payload(
|
||||
static_cast<const uint8_t*>(row->data()),
|
||||
row->data_byte_size());
|
||||
if (row) {
|
||||
payload_writer->add_one_binary_payload(
|
||||
static_cast<const uint8_t*>(row->data()),
|
||||
row->data_byte_size());
|
||||
} else {
|
||||
payload_writer->add_one_binary_payload(nullptr, -1);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
@ -71,13 +71,44 @@ PayloadReader::init(const uint8_t* data, int length, bool is_field_data) {
|
||||
auto file_meta = arrow_reader->parquet_reader()->metadata();
|
||||
|
||||
// dim is unused for sparse float vector
|
||||
dim_ =
|
||||
(IsVectorDataType(column_type_) &&
|
||||
!IsVectorArrayDataType(column_type_) &&
|
||||
!IsSparseFloatVectorDataType(column_type_))
|
||||
? GetDimensionFromFileMetaData(
|
||||
file_meta->schema()->Column(column_index), column_type_)
|
||||
: 1;
|
||||
// For nullable vectors, dim is stored in Arrow schema metadata
|
||||
if (IsVectorDataType(column_type_) &&
|
||||
!IsVectorArrayDataType(column_type_) &&
|
||||
!IsSparseFloatVectorDataType(column_type_)) {
|
||||
if (nullable_) {
|
||||
std::shared_ptr<arrow::Schema> arrow_schema;
|
||||
auto st = arrow_reader->GetSchema(&arrow_schema);
|
||||
AssertInfo(st.ok(), "Failed to get arrow schema");
|
||||
AssertInfo(arrow_schema->num_fields() == 1,
|
||||
"Vector field should have exactly 1 field, got {}",
|
||||
arrow_schema->num_fields());
|
||||
|
||||
auto field = arrow_schema->field(0);
|
||||
if (field->HasMetadata()) {
|
||||
auto metadata = field->metadata();
|
||||
if (metadata->Contains(DIM_KEY)) {
|
||||
auto dim_str = metadata->Get(DIM_KEY).ValueOrDie();
|
||||
dim_ = std::stoi(dim_str);
|
||||
AssertInfo(
|
||||
dim_ > 0,
|
||||
"nullable vector dim must be positive, got {}",
|
||||
dim_);
|
||||
} else {
|
||||
ThrowInfo(DataTypeInvalid,
|
||||
"nullable vector field metadata missing "
|
||||
"required 'dim' field");
|
||||
}
|
||||
} else {
|
||||
ThrowInfo(DataTypeInvalid,
|
||||
"nullable vector field is missing metadata");
|
||||
}
|
||||
} else {
|
||||
dim_ = GetDimensionFromFileMetaData(
|
||||
file_meta->schema()->Column(column_index), column_type_);
|
||||
}
|
||||
} else {
|
||||
dim_ = 1;
|
||||
}
|
||||
|
||||
// For VectorArray, get element type and dim from Arrow schema metadata
|
||||
auto element_type = DataType::NONE;
|
||||
@ -133,8 +164,10 @@ PayloadReader::init(const uint8_t* data, int length, bool is_field_data) {
|
||||
field_data_->FillFieldData(array);
|
||||
}
|
||||
|
||||
AssertInfo(field_data_->IsFull(),
|
||||
"field data hasn't been filled done");
|
||||
if (!nullable_ || !IsVectorDataType(column_type_)) {
|
||||
AssertInfo(field_data_->IsFull(),
|
||||
"field data hasn't been filled done");
|
||||
}
|
||||
} else {
|
||||
arrow_reader_ = std::move(arrow_reader);
|
||||
record_batch_reader_ = std::move(rb_reader);
|
||||
|
||||
@ -35,7 +35,6 @@ PayloadWriter::PayloadWriter(const DataType column_type, int dim, bool nullable)
|
||||
AssertInfo(column_type != DataType::VECTOR_SPARSE_U32_F32,
|
||||
"PayloadWriter for Sparse Float Vector should be created "
|
||||
"using the constructor without dimension");
|
||||
AssertInfo(nullable == false, "only scalcar type support null now");
|
||||
init_dimension(dim);
|
||||
}
|
||||
|
||||
@ -63,7 +62,7 @@ PayloadWriter::init_dimension(int dim) {
|
||||
}
|
||||
|
||||
dimension_ = dim;
|
||||
builder_ = CreateArrowBuilder(column_type_, element_type_, dim);
|
||||
builder_ = CreateArrowBuilder(column_type_, element_type_, dim, nullable_);
|
||||
schema_ = CreateArrowSchema(column_type_, dim, nullable_);
|
||||
}
|
||||
|
||||
@ -112,8 +111,10 @@ PayloadWriter::finish() {
|
||||
|
||||
std::shared_ptr<parquet::ArrowWriterProperties> arrow_properties =
|
||||
parquet::default_arrow_writer_properties();
|
||||
if (column_type_ == DataType::VECTOR_ARRAY) {
|
||||
// For VectorArray, we need to store schema metadata
|
||||
if (column_type_ == DataType::VECTOR_ARRAY ||
|
||||
(nullable_ && IsVectorDataType(column_type_) &&
|
||||
!IsSparseFloatVectorDataType(column_type_))) {
|
||||
// For VectorArray and nullable vectors, we need to store schema metadata
|
||||
parquet::ArrowWriterProperties::Builder arrow_props_builder;
|
||||
arrow_props_builder.store_schema();
|
||||
arrow_properties = arrow_props_builder.build();
|
||||
|
||||
@ -128,13 +128,40 @@ ReadMediumType(BinlogReaderPtr reader) {
|
||||
void
|
||||
add_vector_payload(std::shared_ptr<arrow::ArrayBuilder> builder,
|
||||
uint8_t* values,
|
||||
int length) {
|
||||
const uint8_t* valid_data,
|
||||
bool nullable,
|
||||
int length,
|
||||
int byte_width) {
|
||||
AssertInfo(builder != nullptr, "empty arrow builder");
|
||||
auto binary_builder =
|
||||
std::dynamic_pointer_cast<arrow::FixedSizeBinaryBuilder>(builder);
|
||||
auto ast = binary_builder->AppendValues(values, length);
|
||||
AssertInfo(
|
||||
ast.ok(), "append value to arrow builder failed: {}", ast.ToString());
|
||||
AssertInfo((nullable && valid_data) || !nullable,
|
||||
"valid_data is required for nullable vectors");
|
||||
arrow::Status ast;
|
||||
|
||||
if (nullable) {
|
||||
auto binary_builder =
|
||||
std::dynamic_pointer_cast<arrow::BinaryBuilder>(builder);
|
||||
int valid_index = 0;
|
||||
for (int i = 0; i < length; ++i) {
|
||||
auto bit = (valid_data[i >> 3] >> (i & 0x07)) & 1;
|
||||
if (bit) {
|
||||
ast = binary_builder->Append(values + valid_index * byte_width,
|
||||
byte_width);
|
||||
valid_index++;
|
||||
} else {
|
||||
ast = binary_builder->AppendNull();
|
||||
}
|
||||
AssertInfo(ast.ok(),
|
||||
"append value to arrow builder failed: {}",
|
||||
ast.ToString());
|
||||
}
|
||||
} else {
|
||||
auto binary_builder =
|
||||
std::dynamic_pointer_cast<arrow::FixedSizeBinaryBuilder>(builder);
|
||||
ast = binary_builder->AppendValues(values, length);
|
||||
AssertInfo(ast.ok(),
|
||||
"append value to arrow builder failed: {}",
|
||||
ast.ToString());
|
||||
}
|
||||
}
|
||||
|
||||
// append values for numeric data
|
||||
@ -223,12 +250,64 @@ AddPayloadToArrowBuilder(std::shared_ptr<arrow::ArrayBuilder> builder,
|
||||
break;
|
||||
}
|
||||
|
||||
case DataType::VECTOR_FLOAT16:
|
||||
case DataType::VECTOR_BFLOAT16:
|
||||
case DataType::VECTOR_BINARY:
|
||||
case DataType::VECTOR_INT8:
|
||||
case DataType::VECTOR_FLOAT: {
|
||||
add_vector_payload(builder, const_cast<uint8_t*>(raw_data), length);
|
||||
AssertInfo(payload.dimension.has_value(),
|
||||
"dimension is required for VECTOR_FLOAT");
|
||||
int byte_width = payload.dimension.value() * sizeof(float);
|
||||
add_vector_payload(builder,
|
||||
const_cast<uint8_t*>(raw_data),
|
||||
payload.valid_data,
|
||||
nullable,
|
||||
length,
|
||||
byte_width);
|
||||
break;
|
||||
}
|
||||
case DataType::VECTOR_BINARY: {
|
||||
AssertInfo(payload.dimension.has_value(),
|
||||
"dimension is required for VECTOR_BINARY");
|
||||
int byte_width = (payload.dimension.value() + 7) / 8;
|
||||
add_vector_payload(builder,
|
||||
const_cast<uint8_t*>(raw_data),
|
||||
payload.valid_data,
|
||||
nullable,
|
||||
length,
|
||||
byte_width);
|
||||
break;
|
||||
}
|
||||
case DataType::VECTOR_FLOAT16: {
|
||||
AssertInfo(payload.dimension.has_value(),
|
||||
"dimension is required for VECTOR_FLOAT16");
|
||||
int byte_width = payload.dimension.value() * 2;
|
||||
add_vector_payload(builder,
|
||||
const_cast<uint8_t*>(raw_data),
|
||||
payload.valid_data,
|
||||
nullable,
|
||||
length,
|
||||
byte_width);
|
||||
break;
|
||||
}
|
||||
case DataType::VECTOR_BFLOAT16: {
|
||||
AssertInfo(payload.dimension.has_value(),
|
||||
"dimension is required for VECTOR_BFLOAT16");
|
||||
int byte_width = payload.dimension.value() * 2;
|
||||
add_vector_payload(builder,
|
||||
const_cast<uint8_t*>(raw_data),
|
||||
payload.valid_data,
|
||||
nullable,
|
||||
length,
|
||||
byte_width);
|
||||
break;
|
||||
}
|
||||
case DataType::VECTOR_INT8: {
|
||||
AssertInfo(payload.dimension.has_value(),
|
||||
"dimension is required for VECTOR_INT8");
|
||||
int byte_width = payload.dimension.value() * sizeof(int8_t);
|
||||
add_vector_payload(builder,
|
||||
const_cast<uint8_t*>(raw_data),
|
||||
payload.valid_data,
|
||||
nullable,
|
||||
length,
|
||||
byte_width);
|
||||
break;
|
||||
}
|
||||
case DataType::VECTOR_SPARSE_U32_F32: {
|
||||
@ -380,30 +459,48 @@ CreateArrowBuilder(DataType data_type) {
|
||||
}
|
||||
|
||||
std::shared_ptr<arrow::ArrayBuilder>
|
||||
CreateArrowBuilder(DataType data_type, DataType element_type, int dim) {
|
||||
CreateArrowBuilder(DataType data_type,
|
||||
DataType element_type,
|
||||
int dim,
|
||||
bool nullable) {
|
||||
switch (static_cast<DataType>(data_type)) {
|
||||
case DataType::VECTOR_FLOAT: {
|
||||
AssertInfo(dim > 0, "invalid dim value: {}", dim);
|
||||
if (nullable) {
|
||||
return std::make_shared<arrow::BinaryBuilder>();
|
||||
}
|
||||
return std::make_shared<arrow::FixedSizeBinaryBuilder>(
|
||||
arrow::fixed_size_binary(dim * sizeof(float)));
|
||||
}
|
||||
case DataType::VECTOR_BINARY: {
|
||||
AssertInfo(dim % 8 == 0 && dim > 0, "invalid dim value: {}", dim);
|
||||
if (nullable) {
|
||||
return std::make_shared<arrow::BinaryBuilder>();
|
||||
}
|
||||
return std::make_shared<arrow::FixedSizeBinaryBuilder>(
|
||||
arrow::fixed_size_binary(dim / 8));
|
||||
}
|
||||
case DataType::VECTOR_FLOAT16: {
|
||||
AssertInfo(dim > 0, "invalid dim value: {}", dim);
|
||||
if (nullable) {
|
||||
return std::make_shared<arrow::BinaryBuilder>();
|
||||
}
|
||||
return std::make_shared<arrow::FixedSizeBinaryBuilder>(
|
||||
arrow::fixed_size_binary(dim * sizeof(float16)));
|
||||
}
|
||||
case DataType::VECTOR_BFLOAT16: {
|
||||
AssertInfo(dim > 0, "invalid dim value");
|
||||
if (nullable) {
|
||||
return std::make_shared<arrow::BinaryBuilder>();
|
||||
}
|
||||
return std::make_shared<arrow::FixedSizeBinaryBuilder>(
|
||||
arrow::fixed_size_binary(dim * sizeof(bfloat16)));
|
||||
}
|
||||
case DataType::VECTOR_INT8: {
|
||||
AssertInfo(dim > 0, "invalid dim value");
|
||||
if (nullable) {
|
||||
return std::make_shared<arrow::BinaryBuilder>();
|
||||
}
|
||||
return std::make_shared<arrow::FixedSizeBinaryBuilder>(
|
||||
arrow::fixed_size_binary(dim * sizeof(int8)));
|
||||
}
|
||||
@ -576,6 +673,13 @@ CreateArrowSchema(DataType data_type, int dim, bool nullable) {
|
||||
switch (static_cast<DataType>(data_type)) {
|
||||
case DataType::VECTOR_FLOAT: {
|
||||
AssertInfo(dim > 0, "invalid dim value: {}", dim);
|
||||
if (nullable) {
|
||||
auto metadata = std::shared_ptr<arrow::KeyValueMetadata>(
|
||||
new arrow::KeyValueMetadata());
|
||||
metadata->Append(DIM_KEY, std::to_string(dim));
|
||||
return arrow::schema(
|
||||
{arrow::field("val", arrow::binary(), nullable, metadata)});
|
||||
}
|
||||
return arrow::schema(
|
||||
{arrow::field("val",
|
||||
arrow::fixed_size_binary(dim * sizeof(float)),
|
||||
@ -583,11 +687,25 @@ CreateArrowSchema(DataType data_type, int dim, bool nullable) {
|
||||
}
|
||||
case DataType::VECTOR_BINARY: {
|
||||
AssertInfo(dim % 8 == 0 && dim > 0, "invalid dim value: {}", dim);
|
||||
if (nullable) {
|
||||
auto metadata = std::shared_ptr<arrow::KeyValueMetadata>(
|
||||
new arrow::KeyValueMetadata());
|
||||
metadata->Append(DIM_KEY, std::to_string(dim));
|
||||
return arrow::schema(
|
||||
{arrow::field("val", arrow::binary(), nullable, metadata)});
|
||||
}
|
||||
return arrow::schema({arrow::field(
|
||||
"val", arrow::fixed_size_binary(dim / 8), nullable)});
|
||||
}
|
||||
case DataType::VECTOR_FLOAT16: {
|
||||
AssertInfo(dim > 0, "invalid dim value: {}", dim);
|
||||
if (nullable) {
|
||||
auto metadata = std::shared_ptr<arrow::KeyValueMetadata>(
|
||||
new arrow::KeyValueMetadata());
|
||||
metadata->Append(DIM_KEY, std::to_string(dim));
|
||||
return arrow::schema(
|
||||
{arrow::field("val", arrow::binary(), nullable, metadata)});
|
||||
}
|
||||
return arrow::schema(
|
||||
{arrow::field("val",
|
||||
arrow::fixed_size_binary(dim * sizeof(float16)),
|
||||
@ -595,6 +713,13 @@ CreateArrowSchema(DataType data_type, int dim, bool nullable) {
|
||||
}
|
||||
case DataType::VECTOR_BFLOAT16: {
|
||||
AssertInfo(dim > 0, "invalid dim value");
|
||||
if (nullable) {
|
||||
auto metadata = std::shared_ptr<arrow::KeyValueMetadata>(
|
||||
new arrow::KeyValueMetadata());
|
||||
metadata->Append(DIM_KEY, std::to_string(dim));
|
||||
return arrow::schema(
|
||||
{arrow::field("val", arrow::binary(), nullable, metadata)});
|
||||
}
|
||||
return arrow::schema(
|
||||
{arrow::field("val",
|
||||
arrow::fixed_size_binary(dim * sizeof(bfloat16)),
|
||||
@ -606,6 +731,13 @@ CreateArrowSchema(DataType data_type, int dim, bool nullable) {
|
||||
}
|
||||
case DataType::VECTOR_INT8: {
|
||||
AssertInfo(dim > 0, "invalid dim value");
|
||||
if (nullable) {
|
||||
auto metadata = std::shared_ptr<arrow::KeyValueMetadata>(
|
||||
new arrow::KeyValueMetadata());
|
||||
metadata->Append(DIM_KEY, std::to_string(dim));
|
||||
return arrow::schema(
|
||||
{arrow::field("val", arrow::binary(), nullable, metadata)});
|
||||
}
|
||||
return arrow::schema(
|
||||
{arrow::field("val",
|
||||
arrow::fixed_size_binary(dim * sizeof(int8)),
|
||||
@ -1103,22 +1235,22 @@ CreateFieldData(const DataType& type,
|
||||
type, nullable, total_num_rows);
|
||||
case DataType::VECTOR_FLOAT:
|
||||
return std::make_shared<FieldData<FloatVector>>(
|
||||
dim, type, total_num_rows);
|
||||
dim, type, nullable, total_num_rows);
|
||||
case DataType::VECTOR_BINARY:
|
||||
return std::make_shared<FieldData<BinaryVector>>(
|
||||
dim, type, total_num_rows);
|
||||
dim, type, nullable, total_num_rows);
|
||||
case DataType::VECTOR_FLOAT16:
|
||||
return std::make_shared<FieldData<Float16Vector>>(
|
||||
dim, type, total_num_rows);
|
||||
dim, type, nullable, total_num_rows);
|
||||
case DataType::VECTOR_BFLOAT16:
|
||||
return std::make_shared<FieldData<BFloat16Vector>>(
|
||||
dim, type, total_num_rows);
|
||||
dim, type, nullable, total_num_rows);
|
||||
case DataType::VECTOR_SPARSE_U32_F32:
|
||||
return std::make_shared<FieldData<SparseFloatVector>>(
|
||||
type, total_num_rows);
|
||||
type, nullable, total_num_rows);
|
||||
case DataType::VECTOR_INT8:
|
||||
return std::make_shared<FieldData<Int8Vector>>(
|
||||
dim, type, total_num_rows);
|
||||
dim, type, nullable, total_num_rows);
|
||||
case DataType::VECTOR_ARRAY:
|
||||
return std::make_shared<FieldData<VectorArray>>(
|
||||
dim, element_type, total_num_rows);
|
||||
|
||||
@ -59,7 +59,10 @@ std::shared_ptr<arrow::ArrayBuilder>
|
||||
CreateArrowBuilder(DataType data_type);
|
||||
|
||||
std::shared_ptr<arrow::ArrayBuilder>
|
||||
CreateArrowBuilder(DataType data_type, DataType element_type, int dim);
|
||||
CreateArrowBuilder(DataType data_type,
|
||||
DataType element_type,
|
||||
int dim,
|
||||
bool nullable = false);
|
||||
|
||||
/// \brief Utility function to create arrow:Scalar from FieldMeta.default_value
|
||||
///
|
||||
|
||||
@ -326,6 +326,10 @@ GenerateRandomSparseFloatVector(size_t rows,
|
||||
size_t cols = kTestSparseDim,
|
||||
float density = kTestSparseVectorDensity,
|
||||
int seed = 42) {
|
||||
if (rows == 0) {
|
||||
return std::make_unique<
|
||||
knowhere::sparse::SparseRow<milvus::SparseValueType>[]>(0);
|
||||
}
|
||||
int32_t num_elements = static_cast<int32_t>(rows * cols * density);
|
||||
|
||||
std::mt19937 rng(seed);
|
||||
@ -542,7 +546,8 @@ DataGen(SchemaPtr schema,
|
||||
int group_count = 1,
|
||||
bool random_pk = false,
|
||||
bool random_val = true,
|
||||
bool random_valid = false) {
|
||||
bool random_valid = false,
|
||||
int null_percent = 50) {
|
||||
using std::vector;
|
||||
std::default_random_engine random(seed);
|
||||
std::normal_distribution<> distr(0, 1);
|
||||
@ -635,41 +640,138 @@ DataGen(SchemaPtr schema,
|
||||
return data;
|
||||
};
|
||||
|
||||
auto generate_valid_data = [&](const FieldMeta& field_meta, int64_t N) {
|
||||
struct Result {
|
||||
int64_t valid_count;
|
||||
FixedVector<bool> valid_data;
|
||||
};
|
||||
|
||||
Result result;
|
||||
result.valid_data.resize(N);
|
||||
result.valid_count = 0;
|
||||
|
||||
bool is_nullable = field_meta.is_nullable();
|
||||
if (is_nullable) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
if (random_valid) {
|
||||
int x = rand();
|
||||
result.valid_data[i] = x % 2 == 0 ? true : false;
|
||||
} else {
|
||||
result.valid_data[i] = (i % 100) >= null_percent;
|
||||
}
|
||||
if (result.valid_data[i]) {
|
||||
result.valid_count++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result.valid_count = N;
|
||||
}
|
||||
|
||||
return result;
|
||||
};
|
||||
|
||||
for (auto field_id : schema->get_field_ids()) {
|
||||
auto field_meta = schema->operator[](field_id);
|
||||
switch (field_meta.get_data_type()) {
|
||||
case DataType::VECTOR_FLOAT: {
|
||||
auto data = generate_float_vector(field_meta, N);
|
||||
insert_cols(data, N, field_meta, random_valid);
|
||||
auto [valid_count, valid_data] =
|
||||
generate_valid_data(field_meta, N);
|
||||
bool is_nullable = field_meta.is_nullable();
|
||||
|
||||
auto data = generate_float_vector(field_meta, valid_count);
|
||||
auto array = milvus::segcore::CreateVectorDataArrayFrom(
|
||||
data.data(),
|
||||
is_nullable ? valid_data.data() : nullptr,
|
||||
N,
|
||||
valid_count,
|
||||
field_meta);
|
||||
insert_data->mutable_fields_data()->AddAllocated(
|
||||
array.release());
|
||||
break;
|
||||
}
|
||||
case DataType::VECTOR_BINARY: {
|
||||
auto data = generate_binary_vector(field_meta, N);
|
||||
insert_cols(data, N, field_meta, random_valid);
|
||||
auto [valid_count, valid_data] =
|
||||
generate_valid_data(field_meta, N);
|
||||
bool is_nullable = field_meta.is_nullable();
|
||||
|
||||
auto data = generate_binary_vector(field_meta, valid_count);
|
||||
auto array = milvus::segcore::CreateVectorDataArrayFrom(
|
||||
data.data(),
|
||||
is_nullable ? valid_data.data() : nullptr,
|
||||
N,
|
||||
valid_count,
|
||||
field_meta);
|
||||
insert_data->mutable_fields_data()->AddAllocated(
|
||||
array.release());
|
||||
break;
|
||||
}
|
||||
case DataType::VECTOR_FLOAT16: {
|
||||
auto data = generate_float16_vector(field_meta, N);
|
||||
insert_cols(data, N, field_meta, random_valid);
|
||||
auto [valid_count, valid_data] =
|
||||
generate_valid_data(field_meta, N);
|
||||
bool is_nullable = field_meta.is_nullable();
|
||||
|
||||
auto data = generate_float16_vector(field_meta, valid_count);
|
||||
auto array = milvus::segcore::CreateVectorDataArrayFrom(
|
||||
data.data(),
|
||||
is_nullable ? valid_data.data() : nullptr,
|
||||
N,
|
||||
valid_count,
|
||||
field_meta);
|
||||
insert_data->mutable_fields_data()->AddAllocated(
|
||||
array.release());
|
||||
break;
|
||||
}
|
||||
case DataType::VECTOR_BFLOAT16: {
|
||||
auto data = generate_bfloat16_vector(field_meta, N);
|
||||
insert_cols(data, N, field_meta, random_valid);
|
||||
auto [valid_count, valid_data] =
|
||||
generate_valid_data(field_meta, N);
|
||||
bool is_nullable = field_meta.is_nullable();
|
||||
|
||||
auto data = generate_bfloat16_vector(field_meta, valid_count);
|
||||
auto array = milvus::segcore::CreateVectorDataArrayFrom(
|
||||
data.data(),
|
||||
is_nullable ? valid_data.data() : nullptr,
|
||||
N,
|
||||
valid_count,
|
||||
field_meta);
|
||||
insert_data->mutable_fields_data()->AddAllocated(
|
||||
array.release());
|
||||
break;
|
||||
}
|
||||
case DataType::VECTOR_SPARSE_U32_F32: {
|
||||
auto res = GenerateRandomSparseFloatVector(
|
||||
N, kTestSparseDim, kTestSparseVectorDensity, seed);
|
||||
auto array = milvus::segcore::CreateDataArrayFrom(
|
||||
res.get(), nullptr, N, field_meta);
|
||||
auto [valid_count, valid_data] =
|
||||
generate_valid_data(field_meta, N);
|
||||
bool is_nullable = field_meta.is_nullable();
|
||||
|
||||
auto res =
|
||||
GenerateRandomSparseFloatVector(valid_count,
|
||||
kTestSparseDim,
|
||||
kTestSparseVectorDensity,
|
||||
seed);
|
||||
|
||||
auto array = milvus::segcore::CreateVectorDataArrayFrom(
|
||||
res.get(),
|
||||
is_nullable ? valid_data.data() : nullptr,
|
||||
N,
|
||||
valid_count,
|
||||
field_meta);
|
||||
insert_data->mutable_fields_data()->AddAllocated(
|
||||
array.release());
|
||||
break;
|
||||
}
|
||||
case DataType::VECTOR_INT8: {
|
||||
auto data = generate_int8_vector(field_meta, N);
|
||||
insert_cols(data, N, field_meta, random_valid);
|
||||
auto [valid_count, valid_data] =
|
||||
generate_valid_data(field_meta, N);
|
||||
bool is_nullable = field_meta.is_nullable();
|
||||
|
||||
auto data = generate_int8_vector(field_meta, valid_count);
|
||||
auto array = milvus::segcore::CreateVectorDataArrayFrom(
|
||||
data.data(),
|
||||
is_nullable ? valid_data.data() : nullptr,
|
||||
N,
|
||||
valid_count,
|
||||
field_meta);
|
||||
insert_data->mutable_fields_data()->AddAllocated(
|
||||
array.release());
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
@ -157,13 +157,13 @@ PrepareSingleFieldInsertBinlog(int64_t collection_id,
|
||||
int64_t row_count = 0;
|
||||
for (auto i = 0; i < field_datas.size(); ++i) {
|
||||
auto& field_data = field_datas[i];
|
||||
row_count += field_data->Length();
|
||||
row_count += field_data->get_num_rows();
|
||||
auto file = "./data/test/" + std::to_string(collection_id) + "/" +
|
||||
std::to_string(partition_id) + "/" +
|
||||
std::to_string(segment_id) + "/" +
|
||||
std::to_string(field_id) + "/" + std::to_string(i);
|
||||
files.push_back(file);
|
||||
row_counts.push_back(field_data->Length());
|
||||
row_counts.push_back(field_data->get_num_rows());
|
||||
auto payload_reader =
|
||||
std::make_shared<milvus::storage::PayloadReader>(field_data);
|
||||
auto insert_data = std::make_shared<InsertData>(payload_reader);
|
||||
|
||||
@ -141,7 +141,7 @@ func (s *InsertBufferSuite) TestBuffer() {
|
||||
memSize := insertBuffer.Buffer(groups[0], &msgpb.MsgPosition{Timestamp: 100}, &msgpb.MsgPosition{Timestamp: 200})
|
||||
|
||||
s.EqualValues(100, insertBuffer.MinTimestamp())
|
||||
s.EqualValues(5367, memSize)
|
||||
s.EqualValues(5376, memSize)
|
||||
}
|
||||
|
||||
func (s *InsertBufferSuite) TestYield() {
|
||||
|
||||
@ -195,12 +195,12 @@ func (s *L0WriteBufferSuite) TestBufferData() {
|
||||
|
||||
value, err := metrics.DataNodeFlowGraphBufferDataSize.GetMetricWithLabelValues(fmt.Sprint(paramtable.GetNodeID()), fmt.Sprint(s.metacache.Collection()))
|
||||
s.NoError(err)
|
||||
s.MetricsEqual(value, 5607)
|
||||
s.MetricsEqual(value, 5616)
|
||||
|
||||
delMsg = s.composeDeleteMsg(lo.Map(pks, func(id int64, _ int) storage.PrimaryKey { return storage.NewInt64PrimaryKey(id) }))
|
||||
err = wb.BufferData([]*InsertData{}, []*msgstream.DeleteMsg{delMsg}, &msgpb.MsgPosition{Timestamp: 100}, &msgpb.MsgPosition{Timestamp: 200})
|
||||
s.NoError(err)
|
||||
s.MetricsEqual(value, 5847)
|
||||
s.MetricsEqual(value, 5856)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -65,11 +65,15 @@ func genInsertMsgsByPartition(ctx context.Context,
|
||||
return msg
|
||||
}
|
||||
|
||||
fieldsData := insertMsg.GetFieldsData()
|
||||
idxComputer := typeutil.NewFieldDataIdxComputer(fieldsData)
|
||||
|
||||
repackedMsgs := make([]msgstream.TsMsg, 0)
|
||||
requestSize := 0
|
||||
msg := createInsertMsg(segmentID, channelName)
|
||||
for _, offset := range rowOffsets {
|
||||
curRowMessageSize, err := typeutil.EstimateEntitySize(insertMsg.GetFieldsData(), offset)
|
||||
fieldIdxs := idxComputer.Compute(int64(offset))
|
||||
curRowMessageSize, err := typeutil.EstimateEntitySize(fieldsData, offset, fieldIdxs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -81,7 +85,7 @@ func genInsertMsgsByPartition(ctx context.Context,
|
||||
requestSize = 0
|
||||
}
|
||||
|
||||
typeutil.AppendFieldData(msg.FieldsData, insertMsg.GetFieldsData(), int64(offset))
|
||||
typeutil.AppendFieldData(msg.FieldsData, fieldsData, int64(offset), fieldIdxs...)
|
||||
msg.HashValues = append(msg.HashValues, insertMsg.HashValues[offset])
|
||||
msg.Timestamps = append(msg.Timestamps, insertMsg.Timestamps[offset])
|
||||
msg.RowIDs = append(msg.RowIDs, insertMsg.RowIDs[offset])
|
||||
|
||||
@ -258,6 +258,11 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
|
||||
return ret, merr.WrapErrServiceInternal("failed to construct group by field data builder, this is abnormal as segcore should always set up a group by field, no matter data status, check code on qn", err.Error())
|
||||
}
|
||||
|
||||
idxComputers := make([]*typeutil.FieldDataIdxComputer, subSearchNum)
|
||||
for i, srd := range subSearchResultData {
|
||||
idxComputers[i] = typeutil.NewFieldDataIdxComputer(srd.FieldsData)
|
||||
}
|
||||
|
||||
var realTopK int64 = -1
|
||||
var retSize int64
|
||||
|
||||
@ -316,7 +321,8 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
|
||||
for _, groupEntity := range groupEntities {
|
||||
subResData := subSearchResultData[groupEntity.subSearchIdx]
|
||||
if len(ret.Results.FieldsData) > 0 {
|
||||
retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subResData.FieldsData, groupEntity.resultIdx)
|
||||
fieldIdxs := idxComputers[groupEntity.subSearchIdx].Compute(groupEntity.resultIdx)
|
||||
retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subResData.FieldsData, groupEntity.resultIdx, fieldIdxs...)
|
||||
}
|
||||
typeutil.AppendPKs(ret.Results.Ids, groupEntity.id)
|
||||
ret.Results.Scores = append(ret.Results.Scores, groupEntity.score)
|
||||
@ -424,6 +430,12 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []
|
||||
subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
|
||||
}
|
||||
}
|
||||
|
||||
idxComputers := make([]*typeutil.FieldDataIdxComputer, subSearchNum)
|
||||
for i, srd := range subSearchResultData {
|
||||
idxComputers[i] = typeutil.NewFieldDataIdxComputer(srd.FieldsData)
|
||||
}
|
||||
|
||||
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
||||
// reducing nq * topk results
|
||||
for i := int64(0); i < nq; i++ {
|
||||
@ -456,7 +468,9 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []
|
||||
score := subSearchResultData[subSearchIdx].Scores[resultDataIdx]
|
||||
|
||||
if len(ret.Results.FieldsData) > 0 {
|
||||
retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx)
|
||||
fieldsData := subSearchResultData[subSearchIdx].FieldsData
|
||||
fieldIdxs := idxComputers[subSearchIdx].Compute(resultDataIdx)
|
||||
retSize += typeutil.AppendFieldData(ret.Results.FieldsData, fieldsData, resultDataIdx, fieldIdxs...)
|
||||
}
|
||||
typeutil.CopyPk(ret.Results.Ids, subSearchResultData[subSearchIdx].GetIds(), int(resultDataIdx))
|
||||
ret.Results.Scores = append(ret.Results.Scores, score)
|
||||
|
||||
@ -563,9 +563,6 @@ func (t *addCollectionFieldTask) PreExecute(ctx context.Context) error {
|
||||
return merr.WrapErrParameterInvalid("valid field", fmt.Sprintf("field data type: %s is not supported", t.fieldSchema.GetDataType()))
|
||||
}
|
||||
|
||||
if typeutil.IsVectorType(t.fieldSchema.DataType) {
|
||||
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("not support to add vector field, field name = %s", t.fieldSchema.Name))
|
||||
}
|
||||
if funcutil.SliceContain([]string{common.RowIDFieldName, common.TimeStampFieldName, common.MetaFieldName, common.NamespaceFieldName}, t.fieldSchema.GetName()) {
|
||||
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("not support to add system field, field name = %s", t.fieldSchema.Name))
|
||||
}
|
||||
@ -575,6 +572,17 @@ func (t *addCollectionFieldTask) PreExecute(ctx context.Context) error {
|
||||
if !t.fieldSchema.Nullable {
|
||||
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("added field must be nullable, please check it, field name = %s", t.fieldSchema.Name))
|
||||
}
|
||||
if typeutil.IsVectorType(t.fieldSchema.DataType) && t.fieldSchema.Nullable {
|
||||
if t.fieldSchema.DataType == schemapb.DataType_FloatVector ||
|
||||
t.fieldSchema.DataType == schemapb.DataType_Float16Vector ||
|
||||
t.fieldSchema.DataType == schemapb.DataType_BFloat16Vector ||
|
||||
t.fieldSchema.DataType == schemapb.DataType_BinaryVector ||
|
||||
t.fieldSchema.DataType == schemapb.DataType_Int8Vector {
|
||||
if len(t.fieldSchema.TypeParams) == 0 {
|
||||
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("vector field must have dimension specified, field name = %s", t.fieldSchema.Name))
|
||||
}
|
||||
}
|
||||
}
|
||||
if t.fieldSchema.AutoID {
|
||||
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("only primary field can speficy AutoID with true, field name = %s", t.fieldSchema.Name))
|
||||
}
|
||||
|
||||
@ -790,6 +790,10 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re
|
||||
}
|
||||
|
||||
cursors := make([]int64, len(validRetrieveResults))
|
||||
idxComputers := make([]*typeutil.FieldDataIdxComputer, len(validRetrieveResults))
|
||||
for i, vr := range validRetrieveResults {
|
||||
idxComputers[i] = typeutil.NewFieldDataIdxComputer(vr.GetFieldsData())
|
||||
}
|
||||
|
||||
if queryParams != nil && queryParams.limit != typeutil.Unlimited {
|
||||
// IReduceInOrderForBest will try to get as many results as possible
|
||||
@ -819,7 +823,8 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re
|
||||
if sel == -1 || (reduce.ShouldStopWhenDrained(queryParams.reduceType) && drainOneResult) {
|
||||
break
|
||||
}
|
||||
retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel])
|
||||
fieldIdxs := idxComputers[sel].Compute(cursors[sel])
|
||||
retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel], fieldIdxs...)
|
||||
|
||||
// limit retrieve result to avoid oom
|
||||
if retSize > maxOutputSize {
|
||||
|
||||
@ -1009,16 +1009,30 @@ func TestAddFieldTask(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
||||
|
||||
// not support vector field
|
||||
fSchema = &schemapb.FieldSchema{
|
||||
Name: "vec_field",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
Nullable: true,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "128"},
|
||||
},
|
||||
}
|
||||
bytes, err = proto.Marshal(fSchema)
|
||||
assert.NoError(t, err)
|
||||
task.Schema = bytes
|
||||
err = task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
||||
assert.NoError(t, err)
|
||||
|
||||
fSchema = &schemapb.FieldSchema{
|
||||
Name: "sparse_vec",
|
||||
DataType: schemapb.DataType_SparseFloatVector,
|
||||
Nullable: true,
|
||||
}
|
||||
bytes, err = proto.Marshal(fSchema)
|
||||
assert.NoError(t, err)
|
||||
task.Schema = bytes
|
||||
err = task.PreExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// not support system field
|
||||
fSchema = &schemapb.FieldSchema{
|
||||
@ -2595,11 +2609,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("upsert", func(t *testing.T) {
|
||||
// upsert require pk unique in same batch
|
||||
hash := make([]uint32, nb)
|
||||
for i := 0; i < nb; i++ {
|
||||
hash[i] = uint32(i)
|
||||
}
|
||||
hash := testutils.GenerateHashKeys(nb)
|
||||
task := &upsertTask{
|
||||
upsertMsg: &msgstream.UpsertMsg{
|
||||
InsertMsg: &BaseInsertTask{
|
||||
|
||||
@ -392,6 +392,7 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
|
||||
}
|
||||
|
||||
baseIdx := 0
|
||||
idxComputer := typeutil.NewFieldDataIdxComputer(existFieldData)
|
||||
for _, idx := range updateIdxInUpsert {
|
||||
typeutil.AppendIDs(it.deletePKs, upsertIDs, idx)
|
||||
oldPK := typeutil.GetPK(upsertIDs, int64(idx))
|
||||
@ -399,7 +400,8 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
|
||||
if !ok {
|
||||
return merr.WrapErrParameterInvalidMsg("primary key not found in exist data mapping")
|
||||
}
|
||||
typeutil.AppendFieldData(it.insertFieldData, existFieldData, int64(existIndex))
|
||||
fieldIdxs := idxComputer.Compute(int64(existIndex))
|
||||
typeutil.AppendFieldData(it.insertFieldData, existFieldData, int64(existIndex), fieldIdxs...)
|
||||
err := typeutil.UpdateFieldData(it.insertFieldData, it.upsertMsg.InsertMsg.GetFieldsData(), int64(baseIdx), int64(idx))
|
||||
baseIdx += 1
|
||||
if err != nil {
|
||||
@ -438,8 +440,32 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
|
||||
insertWithNullField = append(insertWithNullField, fieldData)
|
||||
}
|
||||
}
|
||||
for _, idx := range insertIdxInUpsert {
|
||||
typeutil.AppendFieldData(it.insertFieldData, insertWithNullField, int64(idx))
|
||||
vectorIdxMap := make([][]int64, len(insertIdxInUpsert))
|
||||
for rowIdx, offset := range insertIdxInUpsert {
|
||||
vectorIdxMap[rowIdx] = make([]int64, len(insertWithNullField))
|
||||
for fieldIdx := range insertWithNullField {
|
||||
vectorIdxMap[rowIdx][fieldIdx] = int64(offset)
|
||||
}
|
||||
}
|
||||
for fieldIdx, fieldData := range insertWithNullField {
|
||||
validData := fieldData.GetValidData()
|
||||
if len(validData) > 0 && typeutil.IsVectorType(fieldData.Type) {
|
||||
dataIdx := int64(0)
|
||||
rowIdx := 0
|
||||
for i := 0; i < len(validData) && rowIdx < len(insertIdxInUpsert); i++ {
|
||||
if i == insertIdxInUpsert[rowIdx] {
|
||||
vectorIdxMap[rowIdx][fieldIdx] = dataIdx
|
||||
rowIdx++
|
||||
}
|
||||
if validData[i] {
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for rowIdx, idx := range insertIdxInUpsert {
|
||||
typeutil.AppendFieldData(it.insertFieldData, insertWithNullField, int64(idx), vectorIdxMap[rowIdx]...)
|
||||
}
|
||||
}
|
||||
|
||||
@ -620,6 +646,10 @@ func ToCompressedFormatNullable(field *schemapb.FieldData) error {
|
||||
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("undefined data type:%s", field.Type.String()))
|
||||
}
|
||||
|
||||
case *schemapb.FieldData_Vectors:
|
||||
// Vector data is already in compressed format, skip
|
||||
return nil
|
||||
|
||||
default:
|
||||
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("undefined data type:%s", field.Type.String()))
|
||||
}
|
||||
@ -1077,7 +1107,7 @@ func (it *upsertTask) PreExecute(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
// deduplicate upsert data to handle duplicate primary keys in the same batch
|
||||
// check for duplicate primary keys in the same batch
|
||||
primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema.CollectionSchema)
|
||||
if err != nil {
|
||||
log.Warn("fail to get primary field schema", zap.Error(err))
|
||||
|
||||
@ -141,7 +141,7 @@ func (v *validateUtil) Validate(data []*schemapb.FieldData, helper *typeutil.Sch
|
||||
return err
|
||||
}
|
||||
case schemapb.DataType_SparseFloatVector:
|
||||
if err := v.checkSparseFloatFieldData(field, fieldSchema); err != nil {
|
||||
if err := v.checkSparseFloatVectorFieldData(field, fieldSchema); err != nil {
|
||||
return err
|
||||
}
|
||||
case schemapb.DataType_Int8Vector:
|
||||
@ -219,6 +219,13 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
|
||||
msg := fmt.Sprintf("the dim (%d) of field data(%s) is not equal to schema dim (%d)", dataDim, fieldName, schemaDim)
|
||||
return merr.WrapErrParameterInvalid(schemaDim, dataDim, msg)
|
||||
}
|
||||
getExpectedVectorRows := func(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) uint64 {
|
||||
validData := field.GetValidData()
|
||||
if fieldSchema.GetNullable() && len(validData) > 0 {
|
||||
return uint64(getValidNumber(validData))
|
||||
}
|
||||
return numRows
|
||||
}
|
||||
for _, field := range data {
|
||||
switch field.GetType() {
|
||||
case schemapb.DataType_FloatVector:
|
||||
@ -241,7 +248,8 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
|
||||
return errDimMismatch(field.GetFieldName(), dataDim, dim)
|
||||
}
|
||||
|
||||
if n != numRows {
|
||||
expectedRows := getExpectedVectorRows(field, f)
|
||||
if n != expectedRows {
|
||||
return errNumRowsMismatch(field.GetFieldName(), n)
|
||||
}
|
||||
|
||||
@ -265,7 +273,8 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
|
||||
return err
|
||||
}
|
||||
|
||||
if n != numRows {
|
||||
expectedRows := getExpectedVectorRows(field, f)
|
||||
if n != expectedRows {
|
||||
return errNumRowsMismatch(field.GetFieldName(), n)
|
||||
}
|
||||
|
||||
@ -289,7 +298,8 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
|
||||
return err
|
||||
}
|
||||
|
||||
if n != numRows {
|
||||
expectedRows := getExpectedVectorRows(field, f)
|
||||
if n != expectedRows {
|
||||
return errNumRowsMismatch(field.GetFieldName(), n)
|
||||
}
|
||||
|
||||
@ -313,13 +323,19 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
|
||||
return err
|
||||
}
|
||||
|
||||
if n != numRows {
|
||||
expectedRows := getExpectedVectorRows(field, f)
|
||||
if n != expectedRows {
|
||||
return errNumRowsMismatch(field.GetFieldName(), n)
|
||||
}
|
||||
|
||||
case schemapb.DataType_SparseFloatVector:
|
||||
f, err := schema.GetFieldFromName(field.GetFieldName())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n := uint64(len(field.GetVectors().GetSparseFloatVector().Contents))
|
||||
if n != numRows {
|
||||
expectedRows := getExpectedVectorRows(field, f)
|
||||
if n != expectedRows {
|
||||
return errNumRowsMismatch(field.GetFieldName(), n)
|
||||
}
|
||||
|
||||
@ -343,7 +359,8 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
|
||||
return errDimMismatch(field.GetFieldName(), dataDim, dim)
|
||||
}
|
||||
|
||||
if n != numRows {
|
||||
expectedRows := getExpectedVectorRows(field, f)
|
||||
if n != expectedRows {
|
||||
return errNumRowsMismatch(field.GetFieldName(), n)
|
||||
}
|
||||
|
||||
@ -728,7 +745,7 @@ func getValidNumber(validData []bool) int {
|
||||
|
||||
func (v *validateUtil) checkFloatVectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
|
||||
floatArray := field.GetVectors().GetFloatVector().GetData()
|
||||
if floatArray == nil {
|
||||
if floatArray == nil && !fieldSchema.GetNullable() {
|
||||
msg := fmt.Sprintf("float vector field '%v' is illegal, array type mismatch", field.GetFieldName())
|
||||
return merr.WrapErrParameterInvalid("need float vector", "got nil", msg)
|
||||
}
|
||||
@ -743,8 +760,11 @@ func (v *validateUtil) checkFloatVectorFieldData(field *schemapb.FieldData, fiel
|
||||
func (v *validateUtil) checkFloat16VectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
|
||||
float16VecArray := field.GetVectors().GetFloat16Vector()
|
||||
if float16VecArray == nil {
|
||||
msg := fmt.Sprintf("float16 float field '%v' is illegal, nil Vector_Float16 type", field.GetFieldName())
|
||||
return merr.WrapErrParameterInvalid("need vector_float16 array", "got nil", msg)
|
||||
if !fieldSchema.GetNullable() {
|
||||
msg := fmt.Sprintf("float16 vector field '%v' is illegal, array type mismatch", field.GetFieldName())
|
||||
return merr.WrapErrParameterInvalid("need float16 vector", "got nil", msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if v.checkNAN {
|
||||
return typeutil.VerifyFloats16(float16VecArray)
|
||||
@ -755,8 +775,11 @@ func (v *validateUtil) checkFloat16VectorFieldData(field *schemapb.FieldData, fi
|
||||
func (v *validateUtil) checkBFloat16VectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
|
||||
bfloat16VecArray := field.GetVectors().GetBfloat16Vector()
|
||||
if bfloat16VecArray == nil {
|
||||
msg := fmt.Sprintf("bfloat16 float field '%v' is illegal, nil Vector_BFloat16 type", field.GetFieldName())
|
||||
return merr.WrapErrParameterInvalid("need vector_bfloat16 array", "got nil", msg)
|
||||
if !fieldSchema.GetNullable() {
|
||||
msg := fmt.Sprintf("bfloat16 vector field '%v' is illegal, array type mismatch", field.GetFieldName())
|
||||
return merr.WrapErrParameterInvalid("need bfloat16 vector", "got nil", msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if v.checkNAN {
|
||||
return typeutil.VerifyBFloats16(bfloat16VecArray)
|
||||
@ -766,31 +789,33 @@ func (v *validateUtil) checkBFloat16VectorFieldData(field *schemapb.FieldData, f
|
||||
|
||||
func (v *validateUtil) checkBinaryVectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
|
||||
bVecArray := field.GetVectors().GetBinaryVector()
|
||||
if bVecArray == nil {
|
||||
msg := fmt.Sprintf("binary float vector field '%v' is illegal, array type mismatch", field.GetFieldName())
|
||||
return merr.WrapErrParameterInvalid("need bytes array", "got nil", msg)
|
||||
if bVecArray == nil && !fieldSchema.GetNullable() {
|
||||
msg := fmt.Sprintf("binary vector field '%v' is illegal, array type mismatch", field.GetFieldName())
|
||||
return merr.WrapErrParameterInvalid("need binary vector", "got nil", msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *validateUtil) checkSparseFloatFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
|
||||
func (v *validateUtil) checkSparseFloatVectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
|
||||
if field.GetVectors() == nil || field.GetVectors().GetSparseFloatVector() == nil {
|
||||
msg := fmt.Sprintf("sparse float field '%v' is illegal, nil SparseFloatVector", field.GetFieldName())
|
||||
return merr.WrapErrParameterInvalid("need sparse float array", "got nil", msg)
|
||||
if !fieldSchema.GetNullable() {
|
||||
msg := fmt.Sprintf("sparse float vector field '%v' is illegal, array type mismatch", field.GetFieldName())
|
||||
return merr.WrapErrParameterInvalid("need sparse float vector", "got nil", msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
sparseRows := field.GetVectors().GetSparseFloatVector().GetContents()
|
||||
if sparseRows == nil {
|
||||
msg := fmt.Sprintf("sparse float field '%v' is illegal, array type mismatch", field.GetFieldName())
|
||||
return merr.WrapErrParameterInvalid("need sparse float array", "got nil", msg)
|
||||
}
|
||||
return typeutil.ValidateSparseFloatRows(sparseRows...)
|
||||
}
|
||||
|
||||
func (v *validateUtil) checkInt8VectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
|
||||
int8VecArray := field.GetVectors().GetInt8Vector()
|
||||
if int8VecArray == nil {
|
||||
msg := fmt.Sprintf("int8 vector field '%v' is illegal, nil Vector_Int8 type", field.GetFieldName())
|
||||
return merr.WrapErrParameterInvalid("need vector_int8 array", "got nil", msg)
|
||||
if !fieldSchema.GetNullable() {
|
||||
msg := fmt.Sprintf("int8 vector field '%v' is illegal, array type mismatch", field.GetFieldName())
|
||||
return merr.WrapErrParameterInvalid("need int8 vector", "got nil", msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -310,23 +310,77 @@ func Test_validateUtil_checkTextFieldData(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_validateUtil_checkBinaryVectorFieldData(t *testing.T) {
|
||||
v := newValidateUtil()
|
||||
assert.Error(t, v.checkBinaryVectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{}}, nil))
|
||||
assert.NoError(t, v.checkBinaryVectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: 128,
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: []byte(strings.Repeat("1", 128)),
|
||||
t.Run("not binary vector", func(t *testing.T) {
|
||||
v := newValidateUtil()
|
||||
err := v.checkBinaryVectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{}}, nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
v := newValidateUtil()
|
||||
err := v.checkBinaryVectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: 128,
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: []byte(strings.Repeat("1", 128)),
|
||||
},
|
||||
},
|
||||
},
|
||||
}}, nil))
|
||||
}}, nil)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil vector not nullable", func(t *testing.T) {
|
||||
data := &schemapb.FieldData{
|
||||
FieldName: "vec",
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
schema := &schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
DataType: schemapb.DataType_BinaryVector,
|
||||
Nullable: false,
|
||||
}
|
||||
v := newValidateUtil()
|
||||
err := v.checkBinaryVectorFieldData(data, schema)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil vector nullable", func(t *testing.T) {
|
||||
data := &schemapb.FieldData{
|
||||
FieldName: "vec",
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Data: &schemapb.VectorField_BinaryVector{
|
||||
BinaryVector: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
schema := &schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
DataType: schemapb.DataType_BinaryVector,
|
||||
Nullable: true,
|
||||
}
|
||||
v := newValidateUtil()
|
||||
err := v.checkBinaryVectorFieldData(data, schema)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_validateUtil_checkFloatVectorFieldData(t *testing.T) {
|
||||
nb := 5
|
||||
dim := int64(8)
|
||||
data := testutils.GenerateFloatVectors(nb, int(dim))
|
||||
invalidData := testutils.GenerateFloatVectorsWithInvalidData(nb, int(dim))
|
||||
|
||||
t.Run("not float vector", func(t *testing.T) {
|
||||
f := &schemapb.FieldData{}
|
||||
v := newValidateUtil()
|
||||
err := v.checkFloatVectorFieldData(f, nil)
|
||||
err := v.checkFloatVectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{}}, nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
@ -336,7 +390,7 @@ func Test_validateUtil_checkFloatVectorFieldData(t *testing.T) {
|
||||
Vectors: &schemapb.VectorField{
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: []float32{1.1, 2.2},
|
||||
Data: invalidData,
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -354,7 +408,7 @@ func Test_validateUtil_checkFloatVectorFieldData(t *testing.T) {
|
||||
Vectors: &schemapb.VectorField{
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: []float32{float32(math.NaN())},
|
||||
Data: invalidData,
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -371,7 +425,7 @@ func Test_validateUtil_checkFloatVectorFieldData(t *testing.T) {
|
||||
Vectors: &schemapb.VectorField{
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: []float32{1.1, 2.2},
|
||||
Data: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -409,6 +463,49 @@ func Test_validateUtil_checkFloatVectorFieldData(t *testing.T) {
|
||||
err = v.fillWithValue(data, h, 1)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil vector not nullable", func(t *testing.T) {
|
||||
data := &schemapb.FieldData{
|
||||
FieldName: "vec",
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
schema := &schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
Nullable: false,
|
||||
}
|
||||
v := newValidateUtil()
|
||||
err := v.checkFloatVectorFieldData(data, schema)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil vector nullable", func(t *testing.T) {
|
||||
data := &schemapb.FieldData{
|
||||
FieldName: "vec",
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
schema := &schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
Nullable: true,
|
||||
}
|
||||
|
||||
v := newValidateUtil()
|
||||
err := v.checkFloatVectorFieldData(data, schema)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_validateUtil_checkFloat16VectorFieldData(t *testing.T) {
|
||||
@ -418,9 +515,8 @@ func Test_validateUtil_checkFloat16VectorFieldData(t *testing.T) {
|
||||
invalidData := testutils.GenerateFloat16VectorsWithInvalidData(nb, int(dim))
|
||||
|
||||
t.Run("not float16 vector", func(t *testing.T) {
|
||||
f := &schemapb.FieldData{}
|
||||
v := newValidateUtil()
|
||||
err := v.checkFloat16VectorFieldData(f, nil)
|
||||
err := v.checkFloat16VectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{}}, nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
@ -500,17 +596,60 @@ func Test_validateUtil_checkFloat16VectorFieldData(t *testing.T) {
|
||||
err = v.fillWithValue(data, h, 1)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil vector not nullable", func(t *testing.T) {
|
||||
data := &schemapb.FieldData{
|
||||
FieldName: "vec",
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Data: &schemapb.VectorField_Float16Vector{
|
||||
Float16Vector: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
schema := &schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
DataType: schemapb.DataType_Float16Vector,
|
||||
Nullable: false,
|
||||
}
|
||||
v := newValidateUtil()
|
||||
err := v.checkFloat16VectorFieldData(data, schema)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil vector nullable", func(t *testing.T) {
|
||||
data := &schemapb.FieldData{
|
||||
FieldName: "vec",
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Data: &schemapb.VectorField_Float16Vector{
|
||||
Float16Vector: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
schema := &schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
DataType: schemapb.DataType_Float16Vector,
|
||||
Nullable: true,
|
||||
}
|
||||
|
||||
v := newValidateUtil()
|
||||
err := v.checkFloat16VectorFieldData(data, schema)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_validateUtil_checkBfloatVectorFieldData(t *testing.T) {
|
||||
func Test_validateUtil_checkBFloat16VectorFieldData(t *testing.T) {
|
||||
nb := 5
|
||||
dim := int64(8)
|
||||
data := testutils.GenerateFloat16Vectors(nb, int(dim))
|
||||
data := testutils.GenerateBFloat16Vectors(nb, int(dim))
|
||||
invalidData := testutils.GenerateBFloat16VectorsWithInvalidData(nb, int(dim))
|
||||
t.Run("not float vector", func(t *testing.T) {
|
||||
f := &schemapb.FieldData{}
|
||||
|
||||
t.Run("not bfloat16 vector", func(t *testing.T) {
|
||||
v := newValidateUtil()
|
||||
err := v.checkBFloat16VectorFieldData(f, nil)
|
||||
err := v.checkBFloat16VectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{}}, nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
@ -590,6 +729,203 @@ func Test_validateUtil_checkBfloatVectorFieldData(t *testing.T) {
|
||||
err = v.fillWithValue(data, h, 1)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil vector not nullable", func(t *testing.T) {
|
||||
data := &schemapb.FieldData{
|
||||
FieldName: "vec",
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Data: &schemapb.VectorField_Bfloat16Vector{
|
||||
Bfloat16Vector: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
schema := &schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
DataType: schemapb.DataType_BFloat16Vector,
|
||||
Nullable: false,
|
||||
}
|
||||
v := newValidateUtil()
|
||||
err := v.checkBFloat16VectorFieldData(data, schema)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil vector nullable", func(t *testing.T) {
|
||||
data := &schemapb.FieldData{
|
||||
FieldName: "vec",
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Data: &schemapb.VectorField_Bfloat16Vector{
|
||||
Bfloat16Vector: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
schema := &schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
DataType: schemapb.DataType_BFloat16Vector,
|
||||
Nullable: true,
|
||||
}
|
||||
|
||||
v := newValidateUtil()
|
||||
err := v.checkBFloat16VectorFieldData(data, schema)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_validateUtil_checkSparseFloatVectorFieldData(t *testing.T) {
|
||||
nb := 5
|
||||
sparseContents, dim := testutils.GenerateSparseFloatVectorsData(nb)
|
||||
|
||||
t.Run("not sparse float vector", func(t *testing.T) {
|
||||
v := newValidateUtil()
|
||||
err := v.checkSparseFloatVectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{}}, nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
fieldData := &schemapb.FieldData{
|
||||
FieldName: "vec",
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: dim,
|
||||
Data: &schemapb.VectorField_SparseFloatVector{
|
||||
SparseFloatVector: &schemapb.SparseFloatArray{
|
||||
Contents: sparseContents,
|
||||
Dim: dim,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
schema := &schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
DataType: schemapb.DataType_SparseFloatVector,
|
||||
}
|
||||
v := newValidateUtil()
|
||||
err := v.checkSparseFloatVectorFieldData(fieldData, schema)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil vector not nullable", func(t *testing.T) {
|
||||
data := &schemapb.FieldData{
|
||||
FieldName: "vec",
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Data: &schemapb.VectorField_SparseFloatVector{
|
||||
SparseFloatVector: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
schema := &schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
DataType: schemapb.DataType_SparseFloatVector,
|
||||
Nullable: false,
|
||||
}
|
||||
v := newValidateUtil()
|
||||
err := v.checkSparseFloatVectorFieldData(data, schema)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil vector nullable", func(t *testing.T) {
|
||||
data := &schemapb.FieldData{
|
||||
FieldName: "vec",
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Data: &schemapb.VectorField_SparseFloatVector{
|
||||
SparseFloatVector: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
schema := &schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
DataType: schemapb.DataType_SparseFloatVector,
|
||||
Nullable: true,
|
||||
}
|
||||
|
||||
v := newValidateUtil()
|
||||
err := v.checkSparseFloatVectorFieldData(data, schema)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_validateUtil_checkInt8VectorFieldData(t *testing.T) {
|
||||
nb := 5
|
||||
dim := int64(8)
|
||||
data := typeutil.Int8ArrayToBytes(testutils.GenerateInt8Vectors(nb, int(dim)))
|
||||
|
||||
t.Run("not int8 vector", func(t *testing.T) {
|
||||
v := newValidateUtil()
|
||||
err := v.checkInt8VectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{}}, nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
fieldData := &schemapb.FieldData{
|
||||
FieldName: "vec",
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: 128,
|
||||
Data: &schemapb.VectorField_Int8Vector{
|
||||
Int8Vector: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
schema := &schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
DataType: schemapb.DataType_Int8Vector,
|
||||
}
|
||||
v := newValidateUtil()
|
||||
err := v.checkInt8VectorFieldData(fieldData, schema)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil vector not nullable", func(t *testing.T) {
|
||||
fieldData := &schemapb.FieldData{
|
||||
FieldName: "vec",
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Data: &schemapb.VectorField_Int8Vector{
|
||||
Int8Vector: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
schema := &schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
DataType: schemapb.DataType_Int8Vector,
|
||||
Nullable: false,
|
||||
}
|
||||
v := newValidateUtil()
|
||||
err := v.checkInt8VectorFieldData(fieldData, schema)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil vector nullable", func(t *testing.T) {
|
||||
fieldData := &schemapb.FieldData{
|
||||
FieldName: "vec",
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Data: &schemapb.VectorField_Int8Vector{
|
||||
Int8Vector: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
schema := &schemapb.FieldSchema{
|
||||
Name: "vec",
|
||||
DataType: schemapb.DataType_Int8Vector,
|
||||
Nullable: true,
|
||||
}
|
||||
|
||||
v := newValidateUtil()
|
||||
err := v.checkInt8VectorFieldData(fieldData, schema)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_validateUtil_checkAligned(t *testing.T) {
|
||||
|
||||
@ -325,6 +325,11 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna
|
||||
idTsMap := make(map[interface{}]int64)
|
||||
cursors := make([]int64, len(validRetrieveResults))
|
||||
|
||||
idxComputers := make([]*typeutil.FieldDataIdxComputer, len(validRetrieveResults))
|
||||
for i, vr := range validRetrieveResults {
|
||||
idxComputers[i] = typeutil.NewFieldDataIdxComputer(vr.Result.GetFieldsData())
|
||||
}
|
||||
|
||||
var retSize int64
|
||||
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
||||
for j := 0; j < loopEnd; {
|
||||
@ -335,9 +340,11 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna
|
||||
|
||||
pk := typeutil.GetPK(validRetrieveResults[sel].GetIds(), cursors[sel])
|
||||
ts := validRetrieveResults[sel].Timestamps[cursors[sel]]
|
||||
fieldsData := validRetrieveResults[sel].Result.GetFieldsData()
|
||||
fieldIdxs := idxComputers[sel].Compute(cursors[sel])
|
||||
if _, ok := idTsMap[pk]; !ok {
|
||||
typeutil.AppendPKs(ret.Ids, pk)
|
||||
retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].Result.GetFieldsData(), cursors[sel])
|
||||
retSize += typeutil.AppendFieldData(ret.FieldsData, fieldsData, cursors[sel], fieldIdxs...)
|
||||
idTsMap[pk] = ts
|
||||
j++
|
||||
} else {
|
||||
@ -346,7 +353,7 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna
|
||||
if ts != 0 && ts > idTsMap[pk] {
|
||||
idTsMap[pk] = ts
|
||||
typeutil.DeleteFieldData(ret.FieldsData)
|
||||
retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].Result.GetFieldsData(), cursors[sel])
|
||||
retSize += typeutil.AppendFieldData(ret.FieldsData, fieldsData, cursors[sel], fieldIdxs...)
|
||||
}
|
||||
}
|
||||
|
||||
@ -511,10 +518,17 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
|
||||
_, span2 := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "MergeSegcoreResults-AppendFieldData")
|
||||
defer span2.End()
|
||||
ret.FieldsData = typeutil.PrepareResultFieldData(validRetrieveResults[0].Result.GetFieldsData(), int64(len(selections)))
|
||||
// cursors = make([]int64, len(validRetrieveResults))
|
||||
|
||||
idxComputers := make([]*typeutil.FieldDataIdxComputer, len(validRetrieveResults))
|
||||
for i, vr := range validRetrieveResults {
|
||||
idxComputers[i] = typeutil.NewFieldDataIdxComputer(vr.Result.GetFieldsData())
|
||||
}
|
||||
|
||||
for _, selection := range selections {
|
||||
// cannot use `cursors[sel]` directly, since some of them may be skipped.
|
||||
retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[selection.batchIndex].Result.GetFieldsData(), selection.resultIndex)
|
||||
fieldsData := validRetrieveResults[selection.batchIndex].Result.GetFieldsData()
|
||||
fieldIdxs := idxComputers[selection.batchIndex].Compute(selection.resultIndex)
|
||||
retSize += typeutil.AppendFieldData(ret.FieldsData, fieldsData, selection.resultIndex, fieldIdxs...)
|
||||
|
||||
// limit retrieve result to avoid oom
|
||||
if retSize > maxOutputSize {
|
||||
@ -564,10 +578,18 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
|
||||
|
||||
_, span3 := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "MergeSegcoreResults-AppendFieldData")
|
||||
defer span3.End()
|
||||
|
||||
idxComputers := make([]*typeutil.FieldDataIdxComputer, len(segmentResults))
|
||||
for i, r := range segmentResults {
|
||||
idxComputers[i] = typeutil.NewFieldDataIdxComputer(r.GetFieldsData())
|
||||
}
|
||||
|
||||
// retrieve result is compacted, use 0,1,2...end
|
||||
segmentResOffset := make([]int64, len(segmentResults))
|
||||
for _, selection := range selections {
|
||||
retSize += typeutil.AppendFieldData(ret.FieldsData, segmentResults[selection.batchIndex].GetFieldsData(), segmentResOffset[selection.batchIndex])
|
||||
fieldsData := segmentResults[selection.batchIndex].GetFieldsData()
|
||||
fieldIdxs := idxComputers[selection.batchIndex].Compute(segmentResOffset[selection.batchIndex])
|
||||
retSize += typeutil.AppendFieldData(ret.FieldsData, fieldsData, segmentResOffset[selection.batchIndex], fieldIdxs...)
|
||||
segmentResOffset[selection.batchIndex]++
|
||||
// limit retrieve result to avoid oom
|
||||
if retSize > maxOutputSize {
|
||||
|
||||
@ -68,6 +68,11 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc
|
||||
ret.AllSearchCount += searchResultData[i].GetAllSearchCount()
|
||||
}
|
||||
|
||||
idxComputers := make([]*typeutil.FieldDataIdxComputer, len(searchResultData))
|
||||
for i, srd := range searchResultData {
|
||||
idxComputers[i] = typeutil.NewFieldDataIdxComputer(srd.FieldsData)
|
||||
}
|
||||
|
||||
var skipDupCnt int64
|
||||
var retSize int64
|
||||
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
||||
@ -87,7 +92,9 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc
|
||||
|
||||
// remove duplicates
|
||||
if _, ok := idSet[id]; !ok {
|
||||
retSize += typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx)
|
||||
fieldsData := searchResultData[sel].FieldsData
|
||||
fieldIdxs := idxComputers[sel].Compute(idx)
|
||||
retSize += typeutil.AppendFieldData(ret.FieldsData, fieldsData, idx, fieldIdxs...)
|
||||
typeutil.AppendPKs(ret.Ids, id)
|
||||
ret.Scores = append(ret.Scores, score)
|
||||
if searchResultData[sel].ElementIndices != nil && ret.ElementIndices != nil {
|
||||
@ -173,6 +180,11 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
|
||||
return ret, merr.WrapErrServiceInternal("failed to construct group by field data builder, this is abnormal as segcore should always set up a group by field, no matter data status, check code on qn", err.Error())
|
||||
}
|
||||
|
||||
idxComputers := make([]*typeutil.FieldDataIdxComputer, len(searchResultData))
|
||||
for i, srd := range searchResultData {
|
||||
idxComputers[i] = typeutil.NewFieldDataIdxComputer(srd.FieldsData)
|
||||
}
|
||||
|
||||
var filteredCount int64
|
||||
var retSize int64
|
||||
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
||||
@ -208,7 +220,9 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
|
||||
// exceed the limit for each group, filter this entity
|
||||
filteredCount++
|
||||
} else {
|
||||
retSize += typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx)
|
||||
fieldsData := searchResultData[sel].FieldsData
|
||||
fieldIdxs := idxComputers[sel].Compute(idx)
|
||||
retSize += typeutil.AppendFieldData(ret.FieldsData, fieldsData, idx, fieldIdxs...)
|
||||
typeutil.AppendPKs(ret.Ids, id)
|
||||
ret.Scores = append(ret.Scores, score)
|
||||
if searchResultData[sel].ElementIndices != nil && ret.ElementIndices != nil {
|
||||
|
||||
@ -667,14 +667,14 @@ func Test_createCollectionTask_validateSchema(t *testing.T) {
|
||||
DataType: schemapb.DataType_ArrayOfVector,
|
||||
ElementType: schemapb.DataType_FloatVector,
|
||||
Nullable: true,
|
||||
TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := task.validateSchema(context.TODO(), schema)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "vector type not support null")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("struct array field - field with default value", func(t *testing.T) {
|
||||
@ -980,7 +980,7 @@ func Test_createCollectionTask_prepareSchema(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("vector type not support null", func(t *testing.T) {
|
||||
t.Run("vector type with nullable", func(t *testing.T) {
|
||||
collectionName := funcutil.GenRandomStr()
|
||||
field1 := funcutil.GenRandomStr()
|
||||
schema := &schemapb.CollectionSchema{
|
||||
@ -989,9 +989,17 @@ func Test_createCollectionTask_prepareSchema(t *testing.T) {
|
||||
AutoID: false,
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
Name: field1,
|
||||
DataType: 101,
|
||||
Nullable: true,
|
||||
FieldID: 100,
|
||||
Name: "pk",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
IsPrimaryKey: true,
|
||||
},
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: field1,
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
Nullable: true,
|
||||
TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -1005,7 +1013,7 @@ func Test_createCollectionTask_prepareSchema(t *testing.T) {
|
||||
},
|
||||
}
|
||||
err := task.prepareSchema(context.TODO())
|
||||
assert.Error(t, err)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -374,10 +374,6 @@ func checkFieldSchema(fieldSchemas []*schemapb.FieldSchema) error {
|
||||
msg := fmt.Sprintf("ArrayOfVector is only supported in struct array field, type:%s, name:%s", fieldSchema.GetDataType().String(), fieldSchema.GetName())
|
||||
return merr.WrapErrParameterInvalidMsg(msg)
|
||||
}
|
||||
if fieldSchema.GetNullable() && typeutil.IsVectorType(fieldSchema.GetDataType()) {
|
||||
msg := fmt.Sprintf("vector type not support null, type:%s, name:%s", fieldSchema.GetDataType().String(), fieldSchema.GetName())
|
||||
return merr.WrapErrParameterInvalidMsg(msg)
|
||||
}
|
||||
if fieldSchema.GetNullable() && fieldSchema.IsPrimaryKey {
|
||||
msg := fmt.Sprintf("primary field not support null, type:%s, name:%s", fieldSchema.GetDataType().String(), fieldSchema.GetName())
|
||||
return merr.WrapErrParameterInvalidMsg(msg)
|
||||
@ -502,11 +498,6 @@ func checkStructArrayFieldSchema(schemas []*schemapb.StructArrayFieldSchema) err
|
||||
field.DataType.String(), field.ElementType.String(), field.Name)
|
||||
return merr.WrapErrParameterInvalidMsg(msg)
|
||||
}
|
||||
if field.GetNullable() && typeutil.IsVectorType(field.ElementType) {
|
||||
msg := fmt.Sprintf("vector type not support null, data type:%s, element type:%s, name:%s",
|
||||
field.DataType.String(), field.ElementType.String(), field.Name)
|
||||
return merr.WrapErrParameterInvalidMsg(msg)
|
||||
}
|
||||
if field.GetDefaultValue() != nil {
|
||||
msg := fmt.Sprintf("fields in struct array field not support default_value, data type:%s, element type:%s, name:%s",
|
||||
field.DataType.String(), field.ElementType.String(), field.Name)
|
||||
|
||||
@ -391,8 +391,12 @@ func NewRecordBuilder(schema *schemapb.CollectionSchema) *RecordBuilder {
|
||||
if field.DataType == schemapb.DataType_ArrayOfVector {
|
||||
elementType = field.GetElementType()
|
||||
}
|
||||
arrowType := serdeMap[field.DataType].arrowType(int(dim), elementType)
|
||||
builders[i] = array.NewBuilder(memory.DefaultAllocator, arrowType)
|
||||
if field.GetNullable() && typeutil.IsVectorType(field.DataType) && !typeutil.IsSparseFloatVectorType(field.DataType) {
|
||||
builders[i] = array.NewBinaryBuilder(memory.DefaultAllocator, arrow.BinaryTypes.Binary)
|
||||
} else {
|
||||
arrowType := serdeMap[field.DataType].arrowType(int(dim), elementType)
|
||||
builders[i] = array.NewBuilder(memory.DefaultAllocator, arrowType)
|
||||
}
|
||||
}
|
||||
|
||||
return &RecordBuilder{
|
||||
|
||||
@ -448,19 +448,19 @@ func AddFieldDataToPayload(eventWriter *insertEventWriter, dataType schemapb.Dat
|
||||
}
|
||||
}
|
||||
case schemapb.DataType_BinaryVector:
|
||||
if err = eventWriter.AddBinaryVectorToPayload(singleData.(*BinaryVectorFieldData).Data, singleData.(*BinaryVectorFieldData).Dim); err != nil {
|
||||
if err = eventWriter.AddBinaryVectorToPayload(singleData.(*BinaryVectorFieldData).Data, singleData.(*BinaryVectorFieldData).Dim, singleData.(*BinaryVectorFieldData).ValidData); err != nil {
|
||||
return err
|
||||
}
|
||||
case schemapb.DataType_FloatVector:
|
||||
if err = eventWriter.AddFloatVectorToPayload(singleData.(*FloatVectorFieldData).Data, singleData.(*FloatVectorFieldData).Dim); err != nil {
|
||||
if err = eventWriter.AddFloatVectorToPayload(singleData.(*FloatVectorFieldData).Data, singleData.(*FloatVectorFieldData).Dim, singleData.(*FloatVectorFieldData).ValidData); err != nil {
|
||||
return err
|
||||
}
|
||||
case schemapb.DataType_Float16Vector:
|
||||
if err = eventWriter.AddFloat16VectorToPayload(singleData.(*Float16VectorFieldData).Data, singleData.(*Float16VectorFieldData).Dim); err != nil {
|
||||
if err = eventWriter.AddFloat16VectorToPayload(singleData.(*Float16VectorFieldData).Data, singleData.(*Float16VectorFieldData).Dim, singleData.(*Float16VectorFieldData).ValidData); err != nil {
|
||||
return err
|
||||
}
|
||||
case schemapb.DataType_BFloat16Vector:
|
||||
if err = eventWriter.AddBFloat16VectorToPayload(singleData.(*BFloat16VectorFieldData).Data, singleData.(*BFloat16VectorFieldData).Dim); err != nil {
|
||||
if err = eventWriter.AddBFloat16VectorToPayload(singleData.(*BFloat16VectorFieldData).Data, singleData.(*BFloat16VectorFieldData).Dim, singleData.(*BFloat16VectorFieldData).ValidData); err != nil {
|
||||
return err
|
||||
}
|
||||
case schemapb.DataType_SparseFloatVector:
|
||||
@ -468,7 +468,7 @@ func AddFieldDataToPayload(eventWriter *insertEventWriter, dataType schemapb.Dat
|
||||
return err
|
||||
}
|
||||
case schemapb.DataType_Int8Vector:
|
||||
if err = eventWriter.AddInt8VectorToPayload(singleData.(*Int8VectorFieldData).Data, singleData.(*Int8VectorFieldData).Dim); err != nil {
|
||||
if err = eventWriter.AddInt8VectorToPayload(singleData.(*Int8VectorFieldData).Data, singleData.(*Int8VectorFieldData).Dim, singleData.(*Int8VectorFieldData).ValidData); err != nil {
|
||||
return err
|
||||
}
|
||||
case schemapb.DataType_ArrayOfVector:
|
||||
@ -747,6 +747,15 @@ func AddInsertData(dataType schemapb.DataType, data interface{}, insertData *Ins
|
||||
return length, err
|
||||
}
|
||||
binaryVectorFieldData.Dim = dim
|
||||
if validData != nil && len(validData) > 0 {
|
||||
startLogical := len(binaryVectorFieldData.ValidData)
|
||||
if binaryVectorFieldData.ValidData == nil {
|
||||
binaryVectorFieldData.ValidData = make([]bool, 0, rowNum)
|
||||
}
|
||||
binaryVectorFieldData.ValidData = append(binaryVectorFieldData.ValidData, validData...)
|
||||
binaryVectorFieldData.Nullable = true
|
||||
binaryVectorFieldData.L2PMapping.Build(validData, startLogical, len(validData))
|
||||
}
|
||||
insertData.Data[fieldID] = binaryVectorFieldData
|
||||
return length, nil
|
||||
|
||||
@ -763,6 +772,15 @@ func AddInsertData(dataType schemapb.DataType, data interface{}, insertData *Ins
|
||||
return length, err
|
||||
}
|
||||
float16VectorFieldData.Dim = dim
|
||||
if validData != nil && len(validData) > 0 {
|
||||
startLogical := len(float16VectorFieldData.ValidData)
|
||||
if float16VectorFieldData.ValidData == nil {
|
||||
float16VectorFieldData.ValidData = make([]bool, 0, rowNum)
|
||||
}
|
||||
float16VectorFieldData.ValidData = append(float16VectorFieldData.ValidData, validData...)
|
||||
float16VectorFieldData.Nullable = true
|
||||
float16VectorFieldData.L2PMapping.Build(validData, startLogical, len(validData))
|
||||
}
|
||||
insertData.Data[fieldID] = float16VectorFieldData
|
||||
return length, nil
|
||||
|
||||
@ -779,6 +797,15 @@ func AddInsertData(dataType schemapb.DataType, data interface{}, insertData *Ins
|
||||
return length, err
|
||||
}
|
||||
bfloat16VectorFieldData.Dim = dim
|
||||
if validData != nil && len(validData) > 0 {
|
||||
startLogical := len(bfloat16VectorFieldData.ValidData)
|
||||
if bfloat16VectorFieldData.ValidData == nil {
|
||||
bfloat16VectorFieldData.ValidData = make([]bool, 0, rowNum)
|
||||
}
|
||||
bfloat16VectorFieldData.ValidData = append(bfloat16VectorFieldData.ValidData, validData...)
|
||||
bfloat16VectorFieldData.Nullable = true
|
||||
bfloat16VectorFieldData.L2PMapping.Build(validData, startLogical, len(validData))
|
||||
}
|
||||
insertData.Data[fieldID] = bfloat16VectorFieldData
|
||||
return length, nil
|
||||
|
||||
@ -795,6 +822,15 @@ func AddInsertData(dataType schemapb.DataType, data interface{}, insertData *Ins
|
||||
return 0, err
|
||||
}
|
||||
floatVectorFieldData.Dim = dim
|
||||
if validData != nil && len(validData) > 0 {
|
||||
startLogical := len(floatVectorFieldData.ValidData)
|
||||
if floatVectorFieldData.ValidData == nil {
|
||||
floatVectorFieldData.ValidData = make([]bool, 0, rowNum)
|
||||
}
|
||||
floatVectorFieldData.ValidData = append(floatVectorFieldData.ValidData, validData...)
|
||||
floatVectorFieldData.Nullable = true
|
||||
floatVectorFieldData.L2PMapping.Build(validData, startLogical, len(validData))
|
||||
}
|
||||
insertData.Data[fieldID] = floatVectorFieldData
|
||||
return length, nil
|
||||
|
||||
@ -805,6 +841,15 @@ func AddInsertData(dataType schemapb.DataType, data interface{}, insertData *Ins
|
||||
}
|
||||
vec := fieldData.(*SparseFloatVectorFieldData)
|
||||
vec.AppendAllRows(singleData)
|
||||
if validData != nil && len(validData) > 0 {
|
||||
startLogical := len(vec.ValidData)
|
||||
if vec.ValidData == nil {
|
||||
vec.ValidData = make([]bool, 0, rowNum)
|
||||
}
|
||||
vec.ValidData = append(vec.ValidData, validData...)
|
||||
vec.Nullable = true
|
||||
vec.L2PMapping.Build(validData, startLogical, len(validData))
|
||||
}
|
||||
insertData.Data[fieldID] = vec
|
||||
return singleData.RowNum(), nil
|
||||
|
||||
@ -821,6 +866,15 @@ func AddInsertData(dataType schemapb.DataType, data interface{}, insertData *Ins
|
||||
return 0, err
|
||||
}
|
||||
int8VectorFieldData.Dim = dim
|
||||
if validData != nil && len(validData) > 0 {
|
||||
startLogical := len(int8VectorFieldData.ValidData)
|
||||
if int8VectorFieldData.ValidData == nil {
|
||||
int8VectorFieldData.ValidData = make([]bool, 0, rowNum)
|
||||
}
|
||||
int8VectorFieldData.ValidData = append(int8VectorFieldData.ValidData, validData...)
|
||||
int8VectorFieldData.Nullable = true
|
||||
int8VectorFieldData.L2PMapping.Build(validData, startLogical, len(validData))
|
||||
}
|
||||
insertData.Data[fieldID] = int8VectorFieldData
|
||||
return length, nil
|
||||
|
||||
|
||||
@ -35,30 +35,36 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
CollectionID = 1
|
||||
PartitionID = 1
|
||||
SegmentID = 1
|
||||
RowIDField = 0
|
||||
TimestampField = 1
|
||||
BoolField = 100
|
||||
Int8Field = 101
|
||||
Int16Field = 102
|
||||
Int32Field = 103
|
||||
Int64Field = 104
|
||||
FloatField = 105
|
||||
DoubleField = 106
|
||||
StringField = 107
|
||||
BinaryVectorField = 108
|
||||
FloatVectorField = 109
|
||||
ArrayField = 110
|
||||
JSONField = 111
|
||||
Float16VectorField = 112
|
||||
BFloat16VectorField = 113
|
||||
SparseFloatVectorField = 114
|
||||
Int8VectorField = 115
|
||||
StructField = 116
|
||||
StructSubInt32Field = 117
|
||||
StructSubFloatVectorField = 118
|
||||
CollectionID = 1
|
||||
PartitionID = 1
|
||||
SegmentID = 1
|
||||
RowIDField = 0
|
||||
TimestampField = 1
|
||||
BoolField = 100
|
||||
Int8Field = 101
|
||||
Int16Field = 102
|
||||
Int32Field = 103
|
||||
Int64Field = 104
|
||||
FloatField = 105
|
||||
DoubleField = 106
|
||||
StringField = 107
|
||||
BinaryVectorField = 108
|
||||
FloatVectorField = 109
|
||||
ArrayField = 110
|
||||
JSONField = 111
|
||||
Float16VectorField = 112
|
||||
BFloat16VectorField = 113
|
||||
SparseFloatVectorField = 114
|
||||
Int8VectorField = 115
|
||||
StructField = 116
|
||||
StructSubInt32Field = 117
|
||||
StructSubFloatVectorField = 118
|
||||
NullableFloatVectorField = 119
|
||||
NullableBinaryVectorField = 120
|
||||
NullableFloat16VectorField = 121
|
||||
NullableBFloat16VectorField = 122
|
||||
NullableInt8VectorField = 123
|
||||
NullableSparseFloatVectorField = 124
|
||||
)
|
||||
|
||||
func assertTestData(t *testing.T, i int, value *Value) {
|
||||
@ -284,26 +290,35 @@ func generateTestDataWithSeed(seed, num int) ([]*Blob, error) {
|
||||
19: &JSONFieldData{Data: field19},
|
||||
101: &Int32FieldData{Data: field101},
|
||||
102: &FloatVectorFieldData{
|
||||
Data: field102,
|
||||
Dim: 8,
|
||||
Data: field102,
|
||||
ValidData: nil,
|
||||
Dim: 8,
|
||||
Nullable: false,
|
||||
},
|
||||
103: &BinaryVectorFieldData{
|
||||
Data: field103,
|
||||
Dim: 8,
|
||||
Data: field103,
|
||||
ValidData: nil,
|
||||
Dim: 8,
|
||||
Nullable: false,
|
||||
},
|
||||
104: &Float16VectorFieldData{
|
||||
Data: field104,
|
||||
Dim: 8,
|
||||
Data: field104,
|
||||
ValidData: nil,
|
||||
Dim: 8,
|
||||
Nullable: false,
|
||||
},
|
||||
105: &BFloat16VectorFieldData{
|
||||
Data: field105,
|
||||
Dim: 8,
|
||||
Data: field105,
|
||||
ValidData: nil,
|
||||
Dim: 8,
|
||||
Nullable: false,
|
||||
},
|
||||
106: &SparseFloatVectorFieldData{
|
||||
SparseFloatArray: schemapb.SparseFloatArray{
|
||||
Dim: 28433,
|
||||
Contents: field106,
|
||||
},
|
||||
Nullable: false,
|
||||
},
|
||||
}}
|
||||
|
||||
@ -616,6 +631,79 @@ func genTestCollectionMeta() *etcdpb.CollectionMeta {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: NullableFloatVectorField,
|
||||
Name: "field_nullable_float_vector",
|
||||
Description: "nullable_float_vector",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
Nullable: true,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.DimKey,
|
||||
Value: "4",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: NullableBinaryVectorField,
|
||||
Name: "field_nullable_binary_vector",
|
||||
Description: "nullable_binary_vector",
|
||||
DataType: schemapb.DataType_BinaryVector,
|
||||
Nullable: true,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.DimKey,
|
||||
Value: "8",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: NullableFloat16VectorField,
|
||||
Name: "field_nullable_float16_vector",
|
||||
Description: "nullable_float16_vector",
|
||||
DataType: schemapb.DataType_Float16Vector,
|
||||
Nullable: true,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.DimKey,
|
||||
Value: "4",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: NullableBFloat16VectorField,
|
||||
Name: "field_nullable_bfloat16_vector",
|
||||
Description: "nullable_bfloat16_vector",
|
||||
DataType: schemapb.DataType_BFloat16Vector,
|
||||
Nullable: true,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.DimKey,
|
||||
Value: "4",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: NullableInt8VectorField,
|
||||
Name: "field_nullable_int8_vector",
|
||||
Description: "nullable_int8_vector",
|
||||
DataType: schemapb.DataType_Int8Vector,
|
||||
Nullable: true,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.DimKey,
|
||||
Value: "4",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: NullableSparseFloatVectorField,
|
||||
Name: "field_nullable_sparse_float_vector",
|
||||
Description: "nullable_sparse_float_vector",
|
||||
DataType: schemapb.DataType_SparseFloatVector,
|
||||
Nullable: true,
|
||||
TypeParams: []*commonpb.KeyValuePair{},
|
||||
},
|
||||
},
|
||||
StructArrayFields: []*schemapb.StructArrayFieldSchema{
|
||||
{
|
||||
@ -649,63 +737,6 @@ func genTestCollectionMeta() *etcdpb.CollectionMeta {
|
||||
}
|
||||
}
|
||||
|
||||
func TestInsertCodecFailed(t *testing.T) {
|
||||
t.Run("vector field not support null", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
description string
|
||||
dataType schemapb.DataType
|
||||
}{
|
||||
{"nullable FloatVector field", schemapb.DataType_FloatVector},
|
||||
{"nullable Float16Vector field", schemapb.DataType_Float16Vector},
|
||||
{"nullable BinaryVector field", schemapb.DataType_BinaryVector},
|
||||
{"nullable BFloat16Vector field", schemapb.DataType_BFloat16Vector},
|
||||
{"nullable SparseFloatVector field", schemapb.DataType_SparseFloatVector},
|
||||
{"nullable Int8Vector field", schemapb.DataType_Int8Vector},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
schema := &etcdpb.CollectionMeta{
|
||||
ID: CollectionID,
|
||||
CreateTime: 1,
|
||||
SegmentIDs: []int64{SegmentID},
|
||||
PartitionTags: []string{"partition_0", "partition_1"},
|
||||
Schema: &schemapb.CollectionSchema{
|
||||
Name: "schema",
|
||||
Description: "schema",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: RowIDField,
|
||||
Name: "row_id",
|
||||
Description: "row_id",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
FieldID: TimestampField,
|
||||
Name: "Timestamp",
|
||||
Description: "Timestamp",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
DataType: test.dataType,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
insertCodec := NewInsertCodecWithSchema(schema)
|
||||
insertDataEmpty := &InsertData{
|
||||
Data: map[int64]FieldData{
|
||||
RowIDField: &Int64FieldData{[]int64{}, nil, false},
|
||||
TimestampField: &Int64FieldData{[]int64{}, nil, false},
|
||||
},
|
||||
}
|
||||
_, err := insertCodec.Serialize(PartitionID, SegmentID, insertDataEmpty)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestInsertCodec(t *testing.T) {
|
||||
schema := genTestCollectionMeta()
|
||||
insertCodec := NewInsertCodecWithSchema(schema)
|
||||
@ -742,12 +773,16 @@ func TestInsertCodec(t *testing.T) {
|
||||
Data: []string{"3", "4"},
|
||||
},
|
||||
BinaryVectorField: &BinaryVectorFieldData{
|
||||
Data: []byte{0, 255},
|
||||
Dim: 8,
|
||||
Data: []byte{0, 255},
|
||||
ValidData: nil,
|
||||
Dim: 8,
|
||||
Nullable: false,
|
||||
},
|
||||
FloatVectorField: &FloatVectorFieldData{
|
||||
Data: []float32{4, 5, 6, 7, 4, 5, 6, 7},
|
||||
Dim: 4,
|
||||
Data: []float32{4, 5, 6, 7, 4, 5, 6, 7},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
ArrayField: &ArrayFieldData{
|
||||
ElementType: schemapb.DataType_Int32,
|
||||
@ -772,13 +807,17 @@ func TestInsertCodec(t *testing.T) {
|
||||
},
|
||||
Float16VectorField: &Float16VectorFieldData{
|
||||
// length = 2 * Dim * numRows(2) = 16
|
||||
Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255},
|
||||
Dim: 4,
|
||||
Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
BFloat16VectorField: &BFloat16VectorFieldData{
|
||||
// length = 2 * Dim * numRows(2) = 16
|
||||
Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255},
|
||||
Dim: 4,
|
||||
Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
SparseFloatVectorField: &SparseFloatVectorFieldData{
|
||||
SparseFloatArray: schemapb.SparseFloatArray{
|
||||
@ -789,10 +828,13 @@ func TestInsertCodec(t *testing.T) {
|
||||
typeutil.CreateSparseFloatRow([]uint32{100, 200, 599}, []float32{3.1, 3.2, 3.3}),
|
||||
},
|
||||
},
|
||||
Nullable: false,
|
||||
},
|
||||
Int8VectorField: &Int8VectorFieldData{
|
||||
Data: []int8{-4, -5, -6, -7, -4, -5, -6, -7},
|
||||
Dim: 4,
|
||||
Data: []int8{-4, -5, -6, -7, -4, -5, -6, -7},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
StructSubInt32Field: &ArrayFieldData{
|
||||
ElementType: schemapb.DataType_Int32,
|
||||
@ -827,6 +869,36 @@ func TestInsertCodec(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
NullableFloatVectorField: &FloatVectorFieldData{
|
||||
Data: []float32{4.0, 5.0, 6.0, 7.0},
|
||||
ValidData: []bool{true, false},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableBinaryVectorField: &BinaryVectorFieldData{
|
||||
Data: []byte{255},
|
||||
ValidData: []bool{true, false},
|
||||
Dim: 8,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableFloat16VectorField: &Float16VectorFieldData{
|
||||
Data: []byte{255, 0, 255, 0, 255, 0, 255, 0},
|
||||
ValidData: []bool{true, false},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableBFloat16VectorField: &BFloat16VectorFieldData{
|
||||
Data: []byte{255, 0, 255, 0, 255, 0, 255, 0},
|
||||
ValidData: []bool{true, false},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableInt8VectorField: &Int8VectorFieldData{
|
||||
Data: []int8{-4, -5, -6, -7},
|
||||
ValidData: []bool{true, false},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -863,22 +935,30 @@ func TestInsertCodec(t *testing.T) {
|
||||
Data: []string{"1", "2"},
|
||||
},
|
||||
BinaryVectorField: &BinaryVectorFieldData{
|
||||
Data: []byte{0, 255},
|
||||
Dim: 8,
|
||||
Data: []byte{0, 255},
|
||||
ValidData: nil,
|
||||
Dim: 8,
|
||||
Nullable: false,
|
||||
},
|
||||
FloatVectorField: &FloatVectorFieldData{
|
||||
Data: []float32{0, 1, 2, 3, 0, 1, 2, 3},
|
||||
Dim: 4,
|
||||
Data: []float32{0, 1, 2, 3, 0, 1, 2, 3},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
Float16VectorField: &Float16VectorFieldData{
|
||||
// length = 2 * Dim * numRows(2) = 16
|
||||
Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255},
|
||||
Dim: 4,
|
||||
Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
BFloat16VectorField: &BFloat16VectorFieldData{
|
||||
// length = 2 * Dim * numRows(2) = 16
|
||||
Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255},
|
||||
Dim: 4,
|
||||
Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
SparseFloatVectorField: &SparseFloatVectorFieldData{
|
||||
SparseFloatArray: schemapb.SparseFloatArray{
|
||||
@ -889,10 +969,13 @@ func TestInsertCodec(t *testing.T) {
|
||||
typeutil.CreateSparseFloatRow([]uint32{105, 207, 299}, []float32{3.1, 3.2, 3.3}),
|
||||
},
|
||||
},
|
||||
Nullable: false,
|
||||
},
|
||||
Int8VectorField: &Int8VectorFieldData{
|
||||
Data: []int8{0, 1, 2, 3, 0, 1, 2, 3},
|
||||
Dim: 4,
|
||||
Data: []int8{0, 1, 2, 3, 0, 1, 2, 3},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
StructSubInt32Field: &ArrayFieldData{
|
||||
ElementType: schemapb.DataType_Int32,
|
||||
@ -927,6 +1010,46 @@ func TestInsertCodec(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
NullableFloatVectorField: &FloatVectorFieldData{
|
||||
Data: []float32{0.0, 1.0, 2.0, 3.0},
|
||||
ValidData: []bool{true, false},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableBinaryVectorField: &BinaryVectorFieldData{
|
||||
Data: []byte{0},
|
||||
ValidData: []bool{true, false},
|
||||
Dim: 8,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableFloat16VectorField: &Float16VectorFieldData{
|
||||
Data: []byte{0, 255, 0, 255, 0, 255, 0, 255},
|
||||
ValidData: []bool{true, false},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableBFloat16VectorField: &BFloat16VectorFieldData{
|
||||
Data: []byte{0, 255, 0, 255, 0, 255, 0, 255},
|
||||
ValidData: []bool{true, false},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableInt8VectorField: &Int8VectorFieldData{
|
||||
Data: []int8{0, 1, 2, 3},
|
||||
ValidData: []bool{true, false},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableSparseFloatVectorField: &SparseFloatVectorFieldData{
|
||||
SparseFloatArray: schemapb.SparseFloatArray{
|
||||
Dim: 300,
|
||||
Contents: [][]byte{
|
||||
typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{4, 5, 6}),
|
||||
},
|
||||
},
|
||||
ValidData: []bool{true, false},
|
||||
Nullable: true,
|
||||
},
|
||||
ArrayField: &ArrayFieldData{
|
||||
ElementType: schemapb.DataType_Int32,
|
||||
Data: []*schemapb.ScalarField{
|
||||
@ -953,27 +1076,84 @@ func TestInsertCodec(t *testing.T) {
|
||||
|
||||
insertDataEmpty := &InsertData{
|
||||
Data: map[int64]FieldData{
|
||||
RowIDField: &Int64FieldData{[]int64{}, nil, false},
|
||||
TimestampField: &Int64FieldData{[]int64{}, nil, false},
|
||||
BoolField: &BoolFieldData{[]bool{}, nil, false},
|
||||
Int8Field: &Int8FieldData{[]int8{}, nil, false},
|
||||
Int16Field: &Int16FieldData{[]int16{}, nil, false},
|
||||
Int32Field: &Int32FieldData{[]int32{}, nil, false},
|
||||
Int64Field: &Int64FieldData{[]int64{}, nil, false},
|
||||
FloatField: &FloatFieldData{[]float32{}, nil, false},
|
||||
DoubleField: &DoubleFieldData{[]float64{}, nil, false},
|
||||
StringField: &StringFieldData{[]string{}, schemapb.DataType_VarChar, nil, false},
|
||||
BinaryVectorField: &BinaryVectorFieldData{[]byte{}, 8},
|
||||
FloatVectorField: &FloatVectorFieldData{[]float32{}, 4},
|
||||
Float16VectorField: &Float16VectorFieldData{[]byte{}, 4},
|
||||
BFloat16VectorField: &BFloat16VectorFieldData{[]byte{}, 4},
|
||||
RowIDField: &Int64FieldData{[]int64{}, nil, false},
|
||||
TimestampField: &Int64FieldData{[]int64{}, nil, false},
|
||||
BoolField: &BoolFieldData{[]bool{}, nil, false},
|
||||
Int8Field: &Int8FieldData{[]int8{}, nil, false},
|
||||
Int16Field: &Int16FieldData{[]int16{}, nil, false},
|
||||
Int32Field: &Int32FieldData{[]int32{}, nil, false},
|
||||
Int64Field: &Int64FieldData{[]int64{}, nil, false},
|
||||
FloatField: &FloatFieldData{[]float32{}, nil, false},
|
||||
DoubleField: &DoubleFieldData{[]float64{}, nil, false},
|
||||
StringField: &StringFieldData{[]string{}, schemapb.DataType_VarChar, nil, false},
|
||||
BinaryVectorField: &BinaryVectorFieldData{
|
||||
Data: []byte{},
|
||||
ValidData: nil,
|
||||
Dim: 8,
|
||||
Nullable: false,
|
||||
},
|
||||
FloatVectorField: &FloatVectorFieldData{
|
||||
Data: []float32{},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
Float16VectorField: &Float16VectorFieldData{
|
||||
Data: []byte{},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
BFloat16VectorField: &BFloat16VectorFieldData{
|
||||
Data: []byte{},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
SparseFloatVectorField: &SparseFloatVectorFieldData{
|
||||
SparseFloatArray: schemapb.SparseFloatArray{
|
||||
Dim: 0,
|
||||
Contents: [][]byte{},
|
||||
},
|
||||
ValidData: nil,
|
||||
Nullable: false,
|
||||
},
|
||||
Int8VectorField: &Int8VectorFieldData{
|
||||
Data: []int8{},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
NullableFloatVectorField: &FloatVectorFieldData{
|
||||
Data: []float32{},
|
||||
ValidData: []bool{},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableBinaryVectorField: &BinaryVectorFieldData{
|
||||
Data: []byte{},
|
||||
ValidData: []bool{},
|
||||
Dim: 8,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableFloat16VectorField: &Float16VectorFieldData{
|
||||
Data: []byte{},
|
||||
ValidData: []bool{},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableBFloat16VectorField: &BFloat16VectorFieldData{
|
||||
Data: []byte{},
|
||||
ValidData: []bool{},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableInt8VectorField: &Int8VectorFieldData{
|
||||
Data: []int8{},
|
||||
ValidData: []bool{},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
Int8VectorField: &Int8VectorFieldData{[]int8{}, 4},
|
||||
StructSubInt32Field: &ArrayFieldData{schemapb.DataType_Int32, []*schemapb.ScalarField{}, nil, false},
|
||||
ArrayField: &ArrayFieldData{schemapb.DataType_Int32, []*schemapb.ScalarField{}, nil, false},
|
||||
JSONField: &JSONFieldData{[][]byte{}, nil, false},
|
||||
@ -1321,24 +1501,74 @@ func TestMemorySize(t *testing.T) {
|
||||
Data: []string{"3"},
|
||||
},
|
||||
BinaryVectorField: &BinaryVectorFieldData{
|
||||
Data: []byte{0},
|
||||
Dim: 8,
|
||||
Data: []byte{0},
|
||||
ValidData: nil,
|
||||
Dim: 8,
|
||||
Nullable: false,
|
||||
},
|
||||
FloatVectorField: &FloatVectorFieldData{
|
||||
Data: []float32{4, 5, 6, 7},
|
||||
Dim: 4,
|
||||
Data: []float32{4, 5, 6, 7},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
Float16VectorField: &Float16VectorFieldData{
|
||||
Data: []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7},
|
||||
Dim: 4,
|
||||
Data: []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
BFloat16VectorField: &BFloat16VectorFieldData{
|
||||
Data: []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7},
|
||||
Dim: 4,
|
||||
Data: []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
Int8VectorField: &Int8VectorFieldData{
|
||||
Data: []int8{4, 5, 6, 7},
|
||||
Dim: 4,
|
||||
Data: []int8{4, 5, 6, 7},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
NullableFloatVectorField: &FloatVectorFieldData{
|
||||
Data: []float32{4.0, 5.0, 6.0, 7.0},
|
||||
ValidData: []bool{true},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableBinaryVectorField: &BinaryVectorFieldData{
|
||||
Data: []byte{255},
|
||||
ValidData: []bool{true},
|
||||
Dim: 8,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableFloat16VectorField: &Float16VectorFieldData{
|
||||
Data: []byte{0xff, 0x0, 0xff, 0x0, 0xff, 0x0, 0xff, 0x0},
|
||||
ValidData: []bool{true},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableBFloat16VectorField: &BFloat16VectorFieldData{
|
||||
Data: []byte{0xff, 0x0, 0xff, 0x0, 0xff, 0x0, 0xff, 0x0},
|
||||
ValidData: []bool{true},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableInt8VectorField: &Int8VectorFieldData{
|
||||
Data: []int8{4, 5, 6, 7},
|
||||
ValidData: []bool{true},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableSparseFloatVectorField: &SparseFloatVectorFieldData{
|
||||
SparseFloatArray: schemapb.SparseFloatArray{
|
||||
Dim: 300,
|
||||
Contents: [][]byte{
|
||||
typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{4, 5, 6}),
|
||||
},
|
||||
},
|
||||
ValidData: []bool{true},
|
||||
Nullable: true,
|
||||
},
|
||||
ArrayField: &ArrayFieldData{
|
||||
ElementType: schemapb.DataType_Int32,
|
||||
@ -1389,15 +1619,21 @@ func TestMemorySize(t *testing.T) {
|
||||
assert.Equal(t, insertData1.Data[FloatField].GetMemorySize(), 5)
|
||||
assert.Equal(t, insertData1.Data[DoubleField].GetMemorySize(), 9)
|
||||
assert.Equal(t, insertData1.Data[StringField].GetMemorySize(), 18)
|
||||
assert.Equal(t, insertData1.Data[BinaryVectorField].GetMemorySize(), 5)
|
||||
assert.Equal(t, insertData1.Data[FloatVectorField].GetMemorySize(), 20)
|
||||
assert.Equal(t, insertData1.Data[Float16VectorField].GetMemorySize(), 12)
|
||||
assert.Equal(t, insertData1.Data[BFloat16VectorField].GetMemorySize(), 12)
|
||||
assert.Equal(t, insertData1.Data[Int8VectorField].GetMemorySize(), 8)
|
||||
assert.Equal(t, insertData1.Data[ArrayField].GetMemorySize(), 3*4+1)
|
||||
assert.Equal(t, insertData1.Data[JSONField].GetMemorySize(), len([]byte(`{"batch":1}`))+16+1)
|
||||
assert.Equal(t, insertData1.Data[StructSubInt32Field].GetMemorySize(), 4*4+1)
|
||||
assert.Equal(t, insertData1.Data[StructSubFloatVectorField].GetMemorySize(), 4*4+4)
|
||||
assert.Equal(t, insertData1.Data[BinaryVectorField].GetMemorySize(), 14)
|
||||
assert.Equal(t, insertData1.Data[FloatVectorField].GetMemorySize(), 29)
|
||||
assert.Equal(t, insertData1.Data[Float16VectorField].GetMemorySize(), 21)
|
||||
assert.Equal(t, insertData1.Data[BFloat16VectorField].GetMemorySize(), 21)
|
||||
assert.Equal(t, insertData1.Data[Int8VectorField].GetMemorySize(), 17)
|
||||
assert.Equal(t, insertData1.Data[NullableFloatVectorField].GetMemorySize(), 30)
|
||||
assert.Equal(t, insertData1.Data[NullableBinaryVectorField].GetMemorySize(), 15)
|
||||
assert.Equal(t, insertData1.Data[NullableFloat16VectorField].GetMemorySize(), 22)
|
||||
assert.Equal(t, insertData1.Data[NullableBFloat16VectorField].GetMemorySize(), 22)
|
||||
assert.Equal(t, insertData1.Data[NullableInt8VectorField].GetMemorySize(), 18)
|
||||
assert.Equal(t, insertData1.Data[NullableSparseFloatVectorField].GetMemorySize(), 39)
|
||||
assert.Equal(t, insertData1.Data[ArrayField].GetMemorySize(), 13)
|
||||
assert.Equal(t, insertData1.Data[JSONField].GetMemorySize(), 28)
|
||||
assert.Equal(t, insertData1.Data[StructSubInt32Field].GetMemorySize(), 17)
|
||||
assert.Equal(t, insertData1.Data[StructSubFloatVectorField].GetMemorySize(), 20)
|
||||
|
||||
insertData2 := &InsertData{
|
||||
Data: map[int64]FieldData{
|
||||
@ -1432,24 +1668,84 @@ func TestMemorySize(t *testing.T) {
|
||||
Data: []string{"1", "23"},
|
||||
},
|
||||
BinaryVectorField: &BinaryVectorFieldData{
|
||||
Data: []byte{0, 255},
|
||||
Dim: 8,
|
||||
Data: []byte{0, 255},
|
||||
ValidData: nil,
|
||||
Dim: 8,
|
||||
Nullable: false,
|
||||
},
|
||||
FloatVectorField: &FloatVectorFieldData{
|
||||
Data: []float32{0, 1, 2, 3, 0, 1, 2, 3},
|
||||
Dim: 4,
|
||||
Data: []float32{0, 1, 2, 3, 0, 1, 2, 3},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
Float16VectorField: &Float16VectorFieldData{
|
||||
Data: []byte{0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7},
|
||||
Dim: 4,
|
||||
Data: []byte{0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
BFloat16VectorField: &BFloat16VectorFieldData{
|
||||
Data: []byte{0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7},
|
||||
Dim: 4,
|
||||
Data: []byte{0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
Int8VectorField: &Int8VectorFieldData{
|
||||
Data: []int8{0, 1, 2, 3, 0, 1, 2, 3},
|
||||
Dim: 4,
|
||||
Data: []int8{0, 1, 2, 3, 0, 1, 2, 3},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
SparseFloatVectorField: &SparseFloatVectorFieldData{
|
||||
SparseFloatArray: schemapb.SparseFloatArray{
|
||||
Dim: 300,
|
||||
Contents: [][]byte{
|
||||
typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{1.1, 1.2, 1.3}),
|
||||
typeutil.CreateSparseFloatRow([]uint32{10, 20, 30}, []float32{2.1, 2.2, 2.3}),
|
||||
},
|
||||
},
|
||||
Nullable: false,
|
||||
},
|
||||
NullableFloatVectorField: &FloatVectorFieldData{
|
||||
Data: []float32{0.0, 1.0, 2.0, 3.0},
|
||||
ValidData: []bool{true, false},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableBinaryVectorField: &BinaryVectorFieldData{
|
||||
Data: []byte{0},
|
||||
ValidData: []bool{true, false},
|
||||
Dim: 8,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableFloat16VectorField: &Float16VectorFieldData{
|
||||
Data: []byte{0, 1, 2, 3, 4, 5, 6, 7},
|
||||
ValidData: []bool{true, false},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableBFloat16VectorField: &BFloat16VectorFieldData{
|
||||
Data: []byte{0, 1, 2, 3, 4, 5, 6, 7},
|
||||
ValidData: []bool{true, false},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableInt8VectorField: &Int8VectorFieldData{
|
||||
Data: []int8{0, 1, 2, 3},
|
||||
ValidData: []bool{true, false},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableSparseFloatVectorField: &SparseFloatVectorFieldData{
|
||||
SparseFloatArray: schemapb.SparseFloatArray{
|
||||
Dim: 300,
|
||||
Contents: [][]byte{
|
||||
typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{4, 5, 6}),
|
||||
},
|
||||
},
|
||||
ValidData: []bool{true, false},
|
||||
Nullable: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -1464,29 +1760,107 @@ func TestMemorySize(t *testing.T) {
|
||||
assert.Equal(t, insertData2.Data[FloatField].GetMemorySize(), 9)
|
||||
assert.Equal(t, insertData2.Data[DoubleField].GetMemorySize(), 17)
|
||||
assert.Equal(t, insertData2.Data[StringField].GetMemorySize(), 36)
|
||||
assert.Equal(t, insertData2.Data[BinaryVectorField].GetMemorySize(), 6)
|
||||
assert.Equal(t, insertData2.Data[FloatVectorField].GetMemorySize(), 36)
|
||||
assert.Equal(t, insertData2.Data[Float16VectorField].GetMemorySize(), 20)
|
||||
assert.Equal(t, insertData2.Data[BFloat16VectorField].GetMemorySize(), 20)
|
||||
assert.Equal(t, insertData2.Data[Int8VectorField].GetMemorySize(), 12)
|
||||
assert.Equal(t, insertData2.Data[BinaryVectorField].GetMemorySize(), 15)
|
||||
assert.Equal(t, insertData2.Data[FloatVectorField].GetMemorySize(), 45)
|
||||
assert.Equal(t, insertData2.Data[Float16VectorField].GetMemorySize(), 29)
|
||||
assert.Equal(t, insertData2.Data[BFloat16VectorField].GetMemorySize(), 29)
|
||||
assert.Equal(t, insertData2.Data[Int8VectorField].GetMemorySize(), 21)
|
||||
assert.Equal(t, insertData2.Data[SparseFloatVectorField].GetMemorySize(), 64)
|
||||
assert.Equal(t, insertData2.Data[NullableBinaryVectorField].GetMemorySize(), 16)
|
||||
assert.Equal(t, insertData2.Data[NullableFloatVectorField].GetMemorySize(), 31)
|
||||
assert.Equal(t, insertData2.Data[NullableFloat16VectorField].GetMemorySize(), 23)
|
||||
assert.Equal(t, insertData2.Data[NullableBFloat16VectorField].GetMemorySize(), 23)
|
||||
assert.Equal(t, insertData2.Data[NullableInt8VectorField].GetMemorySize(), 19)
|
||||
assert.Equal(t, insertData2.Data[NullableSparseFloatVectorField].GetMemorySize(), 40)
|
||||
|
||||
insertDataEmpty := &InsertData{
|
||||
Data: map[int64]FieldData{
|
||||
RowIDField: &Int64FieldData{[]int64{}, nil, false},
|
||||
TimestampField: &Int64FieldData{[]int64{}, nil, false},
|
||||
BoolField: &BoolFieldData{[]bool{}, nil, false},
|
||||
Int8Field: &Int8FieldData{[]int8{}, nil, false},
|
||||
Int16Field: &Int16FieldData{[]int16{}, nil, false},
|
||||
Int32Field: &Int32FieldData{[]int32{}, nil, false},
|
||||
Int64Field: &Int64FieldData{[]int64{}, nil, false},
|
||||
FloatField: &FloatFieldData{[]float32{}, nil, false},
|
||||
DoubleField: &DoubleFieldData{[]float64{}, nil, false},
|
||||
StringField: &StringFieldData{[]string{}, schemapb.DataType_VarChar, nil, false},
|
||||
BinaryVectorField: &BinaryVectorFieldData{[]byte{}, 8},
|
||||
FloatVectorField: &FloatVectorFieldData{[]float32{}, 4},
|
||||
Float16VectorField: &Float16VectorFieldData{[]byte{}, 4},
|
||||
BFloat16VectorField: &BFloat16VectorFieldData{[]byte{}, 4},
|
||||
Int8VectorField: &Int8VectorFieldData{[]int8{}, 4},
|
||||
RowIDField: &Int64FieldData{[]int64{}, nil, false},
|
||||
TimestampField: &Int64FieldData{[]int64{}, nil, false},
|
||||
BoolField: &BoolFieldData{[]bool{}, nil, false},
|
||||
Int8Field: &Int8FieldData{[]int8{}, nil, false},
|
||||
Int16Field: &Int16FieldData{[]int16{}, nil, false},
|
||||
Int32Field: &Int32FieldData{[]int32{}, nil, false},
|
||||
Int64Field: &Int64FieldData{[]int64{}, nil, false},
|
||||
FloatField: &FloatFieldData{[]float32{}, nil, false},
|
||||
DoubleField: &DoubleFieldData{[]float64{}, nil, false},
|
||||
StringField: &StringFieldData{[]string{}, schemapb.DataType_VarChar, nil, false},
|
||||
BinaryVectorField: &BinaryVectorFieldData{
|
||||
Data: []byte{},
|
||||
ValidData: nil,
|
||||
Dim: 8,
|
||||
Nullable: false,
|
||||
},
|
||||
FloatVectorField: &FloatVectorFieldData{
|
||||
Data: []float32{},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
Float16VectorField: &Float16VectorFieldData{
|
||||
Data: []byte{},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
BFloat16VectorField: &BFloat16VectorFieldData{
|
||||
Data: []byte{},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
Int8VectorField: &Int8VectorFieldData{
|
||||
Data: []int8{},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
},
|
||||
SparseFloatVectorField: &SparseFloatVectorFieldData{
|
||||
SparseFloatArray: schemapb.SparseFloatArray{
|
||||
Dim: 0,
|
||||
Contents: [][]byte{},
|
||||
},
|
||||
ValidData: nil,
|
||||
Nullable: false,
|
||||
},
|
||||
NullableFloatVectorField: &FloatVectorFieldData{
|
||||
Data: []float32{},
|
||||
ValidData: []bool{},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableBinaryVectorField: &BinaryVectorFieldData{
|
||||
Data: []byte{},
|
||||
ValidData: []bool{},
|
||||
Dim: 8,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableFloat16VectorField: &Float16VectorFieldData{
|
||||
Data: []byte{},
|
||||
ValidData: []bool{},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableBFloat16VectorField: &BFloat16VectorFieldData{
|
||||
Data: []byte{},
|
||||
ValidData: []bool{},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableInt8VectorField: &Int8VectorFieldData{
|
||||
Data: []int8{},
|
||||
ValidData: []bool{},
|
||||
Dim: 4,
|
||||
Nullable: true,
|
||||
},
|
||||
NullableSparseFloatVectorField: &SparseFloatVectorFieldData{
|
||||
SparseFloatArray: schemapb.SparseFloatArray{
|
||||
Dim: 0,
|
||||
Contents: [][]byte{},
|
||||
},
|
||||
ValidData: []bool{},
|
||||
Nullable: true,
|
||||
},
|
||||
StructSubFloatVectorField: &VectorArrayFieldData{
|
||||
Dim: 2,
|
||||
ElementType: schemapb.DataType_FloatVector,
|
||||
@ -1505,11 +1879,18 @@ func TestMemorySize(t *testing.T) {
|
||||
assert.Equal(t, insertDataEmpty.Data[FloatField].GetMemorySize(), 1)
|
||||
assert.Equal(t, insertDataEmpty.Data[DoubleField].GetMemorySize(), 1)
|
||||
assert.Equal(t, insertDataEmpty.Data[StringField].GetMemorySize(), 1)
|
||||
assert.Equal(t, insertDataEmpty.Data[BinaryVectorField].GetMemorySize(), 4)
|
||||
assert.Equal(t, insertDataEmpty.Data[FloatVectorField].GetMemorySize(), 4)
|
||||
assert.Equal(t, insertDataEmpty.Data[Float16VectorField].GetMemorySize(), 4)
|
||||
assert.Equal(t, insertDataEmpty.Data[BFloat16VectorField].GetMemorySize(), 4)
|
||||
assert.Equal(t, insertDataEmpty.Data[Int8VectorField].GetMemorySize(), 4)
|
||||
assert.Equal(t, insertDataEmpty.Data[BinaryVectorField].GetMemorySize(), 13)
|
||||
assert.Equal(t, insertDataEmpty.Data[FloatVectorField].GetMemorySize(), 13)
|
||||
assert.Equal(t, insertDataEmpty.Data[Float16VectorField].GetMemorySize(), 13)
|
||||
assert.Equal(t, insertDataEmpty.Data[BFloat16VectorField].GetMemorySize(), 13)
|
||||
assert.Equal(t, insertDataEmpty.Data[Int8VectorField].GetMemorySize(), 13)
|
||||
assert.Equal(t, insertDataEmpty.Data[SparseFloatVectorField].GetMemorySize(), 9)
|
||||
assert.Equal(t, insertDataEmpty.Data[NullableFloatVectorField].GetMemorySize(), 13)
|
||||
assert.Equal(t, insertDataEmpty.Data[NullableBinaryVectorField].GetMemorySize(), 13)
|
||||
assert.Equal(t, insertDataEmpty.Data[NullableFloat16VectorField].GetMemorySize(), 13)
|
||||
assert.Equal(t, insertDataEmpty.Data[NullableBFloat16VectorField].GetMemorySize(), 13)
|
||||
assert.Equal(t, insertDataEmpty.Data[NullableInt8VectorField].GetMemorySize(), 13)
|
||||
assert.Equal(t, insertDataEmpty.Data[NullableSparseFloatVectorField].GetMemorySize(), 9)
|
||||
assert.Equal(t, insertDataEmpty.Data[StructSubFloatVectorField].GetMemorySize(), 0)
|
||||
}
|
||||
|
||||
@ -1589,22 +1970,49 @@ func TestAddFieldDataToPayload(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
err = AddFieldDataToPayload(e, schemapb.DataType_JSON, &JSONFieldData{[][]byte{[]byte(`"batch":2}`)}, nil, false})
|
||||
assert.Error(t, err)
|
||||
err = AddFieldDataToPayload(e, schemapb.DataType_BinaryVector, &BinaryVectorFieldData{[]byte{}, 8})
|
||||
err = AddFieldDataToPayload(e, schemapb.DataType_BinaryVector, &BinaryVectorFieldData{
|
||||
Data: []byte{},
|
||||
ValidData: nil,
|
||||
Dim: 8,
|
||||
Nullable: false,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
err = AddFieldDataToPayload(e, schemapb.DataType_FloatVector, &FloatVectorFieldData{[]float32{}, 4})
|
||||
err = AddFieldDataToPayload(e, schemapb.DataType_FloatVector, &FloatVectorFieldData{
|
||||
Data: []float32{},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
err = AddFieldDataToPayload(e, schemapb.DataType_Float16Vector, &Float16VectorFieldData{[]byte{}, 4})
|
||||
err = AddFieldDataToPayload(e, schemapb.DataType_Float16Vector, &Float16VectorFieldData{
|
||||
Data: []byte{},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
err = AddFieldDataToPayload(e, schemapb.DataType_BFloat16Vector, &BFloat16VectorFieldData{[]byte{}, 8})
|
||||
err = AddFieldDataToPayload(e, schemapb.DataType_BFloat16Vector, &BFloat16VectorFieldData{
|
||||
Data: []byte{},
|
||||
ValidData: nil,
|
||||
Dim: 8,
|
||||
Nullable: false,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
err = AddFieldDataToPayload(e, schemapb.DataType_SparseFloatVector, &SparseFloatVectorFieldData{
|
||||
SparseFloatArray: schemapb.SparseFloatArray{
|
||||
Dim: 0,
|
||||
Contents: [][]byte{},
|
||||
},
|
||||
ValidData: nil,
|
||||
Nullable: false,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
err = AddFieldDataToPayload(e, schemapb.DataType_Int8Vector, &Int8VectorFieldData{[]int8{}, 4})
|
||||
err = AddFieldDataToPayload(e, schemapb.DataType_Int8Vector, &Int8VectorFieldData{
|
||||
Data: []int8{},
|
||||
ValidData: nil,
|
||||
Dim: 4,
|
||||
Nullable: false,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
err = AddFieldDataToPayload(e, schemapb.DataType_ArrayOfVector, &VectorArrayFieldData{
|
||||
Dim: 2,
|
||||
|
||||
@ -195,70 +195,83 @@ func NewFieldData(dataType schemapb.DataType, fieldSchema *schemapb.FieldSchema,
|
||||
typeParams := fieldSchema.GetTypeParams()
|
||||
switch dataType {
|
||||
case schemapb.DataType_Float16Vector:
|
||||
if fieldSchema.GetNullable() {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("vector not support null")
|
||||
}
|
||||
dim, err := GetDimFromParams(typeParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Float16VectorFieldData{
|
||||
Data: make([]byte, 0, cap),
|
||||
Dim: dim,
|
||||
}, nil
|
||||
data := &Float16VectorFieldData{
|
||||
Data: make([]byte, 0, cap),
|
||||
Dim: dim,
|
||||
Nullable: fieldSchema.GetNullable(),
|
||||
}
|
||||
if fieldSchema.GetNullable() {
|
||||
data.ValidData = make([]bool, 0, cap)
|
||||
}
|
||||
return data, nil
|
||||
case schemapb.DataType_BFloat16Vector:
|
||||
if fieldSchema.GetNullable() {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("vector not support null")
|
||||
}
|
||||
dim, err := GetDimFromParams(typeParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &BFloat16VectorFieldData{
|
||||
Data: make([]byte, 0, cap),
|
||||
Dim: dim,
|
||||
}, nil
|
||||
data := &BFloat16VectorFieldData{
|
||||
Data: make([]byte, 0, cap),
|
||||
Dim: dim,
|
||||
Nullable: fieldSchema.GetNullable(),
|
||||
}
|
||||
if fieldSchema.GetNullable() {
|
||||
data.ValidData = make([]bool, 0, cap)
|
||||
}
|
||||
return data, nil
|
||||
case schemapb.DataType_FloatVector:
|
||||
if fieldSchema.GetNullable() {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("vector not support null")
|
||||
}
|
||||
dim, err := GetDimFromParams(typeParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &FloatVectorFieldData{
|
||||
Data: make([]float32, 0, cap),
|
||||
Dim: dim,
|
||||
}, nil
|
||||
data := &FloatVectorFieldData{
|
||||
Data: make([]float32, 0, cap),
|
||||
Dim: dim,
|
||||
Nullable: fieldSchema.GetNullable(),
|
||||
}
|
||||
if fieldSchema.GetNullable() {
|
||||
data.ValidData = make([]bool, 0, cap)
|
||||
}
|
||||
return data, nil
|
||||
case schemapb.DataType_BinaryVector:
|
||||
if fieldSchema.GetNullable() {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("vector not support null")
|
||||
}
|
||||
dim, err := GetDimFromParams(typeParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &BinaryVectorFieldData{
|
||||
Data: make([]byte, 0, cap),
|
||||
Dim: dim,
|
||||
}, nil
|
||||
data := &BinaryVectorFieldData{
|
||||
Data: make([]byte, 0, cap),
|
||||
Dim: dim,
|
||||
Nullable: fieldSchema.GetNullable(),
|
||||
}
|
||||
if fieldSchema.GetNullable() {
|
||||
data.ValidData = make([]bool, 0, cap)
|
||||
}
|
||||
return data, nil
|
||||
case schemapb.DataType_SparseFloatVector:
|
||||
if fieldSchema.GetNullable() {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("vector not support null")
|
||||
data := &SparseFloatVectorFieldData{
|
||||
Nullable: fieldSchema.GetNullable(),
|
||||
}
|
||||
return &SparseFloatVectorFieldData{}, nil
|
||||
if fieldSchema.GetNullable() {
|
||||
data.ValidData = make([]bool, 0, cap)
|
||||
}
|
||||
return data, nil
|
||||
case schemapb.DataType_Int8Vector:
|
||||
if fieldSchema.GetNullable() {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("vector not support null")
|
||||
}
|
||||
dim, err := GetDimFromParams(typeParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Int8VectorFieldData{
|
||||
Data: make([]int8, 0, cap),
|
||||
Dim: dim,
|
||||
}, nil
|
||||
data := &Int8VectorFieldData{
|
||||
Data: make([]int8, 0, cap),
|
||||
Dim: dim,
|
||||
Nullable: fieldSchema.GetNullable(),
|
||||
}
|
||||
if fieldSchema.GetNullable() {
|
||||
data.ValidData = make([]bool, 0, cap)
|
||||
}
|
||||
return data, nil
|
||||
case schemapb.DataType_Bool:
|
||||
data := &BoolFieldData{
|
||||
Data: make([]bool, 0, cap),
|
||||
@ -453,30 +466,97 @@ type GeometryFieldData struct {
|
||||
ValidData []bool
|
||||
Nullable bool
|
||||
}
|
||||
|
||||
// LogicalToPhysicalMapping maps logical offset to physical offset for nullable vector.
//
// A logical offset is a row index that counts every row, valid or not. A
// physical offset counts only the valid rows seen so far. The zero value has
// no map and behaves as the identity mapping.
type LogicalToPhysicalMapping struct {
	validCount int         // number of valid rows recorded by Build so far
	l2pMap     map[int]int // logical offset -> physical offset, valid rows only
}

// GetPhysicalOffset returns the physical offset recorded for logicalOffset.
// When no mapping has been built yet it returns logicalOffset unchanged; when
// a mapping exists but holds no entry for logicalOffset it returns -1.
func (m *LogicalToPhysicalMapping) GetPhysicalOffset(logicalOffset int) int {
	if m.l2pMap == nil {
		return logicalOffset
	}
	if physicalOffset, ok := m.l2pMap[logicalOffset]; ok {
		return physicalOffset
	}
	return -1
}

// GetMemorySize returns a rough estimate, in bytes, of the mapping's footprint.
func (m *LogicalToPhysicalMapping) GetMemorySize() int {
	size := 8                  // validCount int
	size += len(m.l2pMap) * 16 // map[int]int, roughly 16 bytes per entry
	return size
}

// GetValidCount returns the number of valid rows recorded so far.
func (m *LogicalToPhysicalMapping) GetValidCount() int {
	return m.validCount
}

// Build extends the mapping with the first totalCount flags of validData.
// Flag i describes logical offset startLogical+i; every true flag is assigned
// the next free physical offset.
//
// The call is a no-op when totalCount is not positive or when validData holds
// fewer than totalCount flags.
func (m *LogicalToPhysicalMapping) Build(validData []bool, startLogical, totalCount int) {
	// A non-positive count describes no rows; returning before the map is
	// allocated keeps an untouched mapping in identity mode.
	if totalCount <= 0 {
		return
	}
	if len(validData) < totalCount {
		return
	}

	if m.l2pMap == nil {
		m.l2pMap = make(map[int]int)
	}

	physicalIdx := m.validCount
	for i, valid := range validData[:totalCount] {
		if valid {
			m.l2pMap[startLogical+i] = physicalIdx
			physicalIdx++
		}
	}
	m.validCount = physicalIdx
}
|
||||
|
||||
type BinaryVectorFieldData struct {
|
||||
Data []byte
|
||||
Dim int
|
||||
Data []byte
|
||||
ValidData []bool
|
||||
Dim int
|
||||
Nullable bool
|
||||
L2PMapping LogicalToPhysicalMapping
|
||||
}
|
||||
type FloatVectorFieldData struct {
|
||||
Data []float32
|
||||
Dim int
|
||||
Data []float32
|
||||
ValidData []bool
|
||||
Dim int
|
||||
Nullable bool
|
||||
L2PMapping LogicalToPhysicalMapping
|
||||
}
|
||||
type Float16VectorFieldData struct {
|
||||
Data []byte
|
||||
Dim int
|
||||
Data []byte
|
||||
ValidData []bool
|
||||
Dim int
|
||||
Nullable bool
|
||||
L2PMapping LogicalToPhysicalMapping
|
||||
}
|
||||
type BFloat16VectorFieldData struct {
|
||||
Data []byte
|
||||
Dim int
|
||||
Data []byte
|
||||
ValidData []bool
|
||||
Dim int
|
||||
Nullable bool
|
||||
L2PMapping LogicalToPhysicalMapping
|
||||
}
|
||||
|
||||
type SparseFloatVectorFieldData struct {
|
||||
schemapb.SparseFloatArray
|
||||
ValidData []bool
|
||||
Nullable bool
|
||||
L2PMapping LogicalToPhysicalMapping
|
||||
}
|
||||
|
||||
type Int8VectorFieldData struct {
|
||||
Data []int8
|
||||
Dim int
|
||||
Data []int8
|
||||
ValidData []bool
|
||||
Dim int
|
||||
Nullable bool
|
||||
L2PMapping LogicalToPhysicalMapping
|
||||
}
|
||||
|
||||
type VectorArrayFieldData struct {
|
||||
@ -493,29 +573,71 @@ func (dst *SparseFloatVectorFieldData) AppendAllRows(src *SparseFloatVectorField
|
||||
dst.Dim = src.Dim
|
||||
}
|
||||
dst.Contents = append(dst.Contents, src.Contents...)
|
||||
if src.Nullable {
|
||||
if dst.ValidData == nil {
|
||||
dst.ValidData = make([]bool, 0, len(src.ValidData))
|
||||
}
|
||||
dst.L2PMapping.Build(src.ValidData, len(dst.ValidData), len(src.ValidData))
|
||||
dst.ValidData = append(dst.ValidData, src.ValidData...)
|
||||
dst.Nullable = true
|
||||
}
|
||||
}
|
||||
|
||||
// RowNum implements FieldData.RowNum
|
||||
func (data *BoolFieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *Int8FieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *Int16FieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *Int32FieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *Int64FieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *FloatFieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *DoubleFieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *TimestamptzFieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *StringFieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *ArrayFieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *JSONFieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *GeometryFieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *BinaryVectorFieldData) RowNum() int { return len(data.Data) * 8 / data.Dim }
|
||||
func (data *FloatVectorFieldData) RowNum() int { return len(data.Data) / data.Dim }
|
||||
func (data *Float16VectorFieldData) RowNum() int { return len(data.Data) / 2 / data.Dim }
|
||||
func (data *BFloat16VectorFieldData) RowNum() int {
|
||||
func (data *BoolFieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *Int8FieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *Int16FieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *Int32FieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *Int64FieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *FloatFieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *DoubleFieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *TimestamptzFieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *StringFieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *ArrayFieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *JSONFieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *GeometryFieldData) RowNum() int { return len(data.Data) }
|
||||
func (data *BinaryVectorFieldData) RowNum() int {
|
||||
if data.Nullable {
|
||||
return len(data.ValidData)
|
||||
}
|
||||
return len(data.Data) * 8 / data.Dim
|
||||
}
|
||||
|
||||
func (data *FloatVectorFieldData) RowNum() int {
|
||||
if data.Nullable {
|
||||
return len(data.ValidData)
|
||||
}
|
||||
return len(data.Data) / data.Dim
|
||||
}
|
||||
|
||||
func (data *Float16VectorFieldData) RowNum() int {
|
||||
if data.Nullable {
|
||||
return len(data.ValidData)
|
||||
}
|
||||
return len(data.Data) / 2 / data.Dim
|
||||
}
|
||||
func (data *SparseFloatVectorFieldData) RowNum() int { return len(data.Contents) }
|
||||
func (data *Int8VectorFieldData) RowNum() int { return len(data.Data) / data.Dim }
|
||||
|
||||
func (data *BFloat16VectorFieldData) RowNum() int {
|
||||
if data.Nullable {
|
||||
return len(data.ValidData)
|
||||
}
|
||||
return len(data.Data) / 2 / data.Dim
|
||||
}
|
||||
|
||||
func (data *SparseFloatVectorFieldData) RowNum() int {
|
||||
if data.Nullable {
|
||||
return len(data.ValidData)
|
||||
}
|
||||
return len(data.Contents)
|
||||
}
|
||||
|
||||
func (data *Int8VectorFieldData) RowNum() int {
|
||||
if data.Nullable {
|
||||
return len(data.ValidData)
|
||||
}
|
||||
return len(data.Data) / data.Dim
|
||||
}
|
||||
|
||||
func (data *VectorArrayFieldData) RowNum() int {
|
||||
return len(data.Data)
|
||||
}
|
||||
@ -606,27 +728,51 @@ func (data *GeometryFieldData) GetRow(i int) any {
|
||||
}
|
||||
|
||||
func (data *BinaryVectorFieldData) GetRow(i int) any {
|
||||
return data.Data[i*data.Dim/8 : (i+1)*data.Dim/8]
|
||||
if data.GetNullable() && !data.ValidData[i] {
|
||||
return nil
|
||||
}
|
||||
physicalIdx := data.L2PMapping.GetPhysicalOffset(i)
|
||||
return data.Data[physicalIdx*data.Dim/8 : (physicalIdx+1)*data.Dim/8]
|
||||
}
|
||||
|
||||
func (data *SparseFloatVectorFieldData) GetRow(i int) interface{} {
|
||||
return data.Contents[i]
|
||||
if data.GetNullable() && !data.ValidData[i] {
|
||||
return nil
|
||||
}
|
||||
physicalIdx := data.L2PMapping.GetPhysicalOffset(i)
|
||||
return data.Contents[physicalIdx]
|
||||
}
|
||||
|
||||
func (data *FloatVectorFieldData) GetRow(i int) interface{} {
|
||||
return data.Data[i*data.Dim : (i+1)*data.Dim]
|
||||
if data.GetNullable() && !data.ValidData[i] {
|
||||
return nil
|
||||
}
|
||||
physicalIdx := data.L2PMapping.GetPhysicalOffset(i)
|
||||
return data.Data[physicalIdx*data.Dim : (physicalIdx+1)*data.Dim]
|
||||
}
|
||||
|
||||
func (data *Float16VectorFieldData) GetRow(i int) interface{} {
|
||||
return data.Data[i*data.Dim*2 : (i+1)*data.Dim*2]
|
||||
if data.GetNullable() && !data.ValidData[i] {
|
||||
return nil
|
||||
}
|
||||
physicalIdx := data.L2PMapping.GetPhysicalOffset(i)
|
||||
return data.Data[physicalIdx*data.Dim*2 : (physicalIdx+1)*data.Dim*2]
|
||||
}
|
||||
|
||||
func (data *BFloat16VectorFieldData) GetRow(i int) interface{} {
|
||||
return data.Data[i*data.Dim*2 : (i+1)*data.Dim*2]
|
||||
if data.GetNullable() && !data.ValidData[i] {
|
||||
return nil
|
||||
}
|
||||
physicalIdx := data.L2PMapping.GetPhysicalOffset(i)
|
||||
return data.Data[physicalIdx*data.Dim*2 : (physicalIdx+1)*data.Dim*2]
|
||||
}
|
||||
|
||||
func (data *Int8VectorFieldData) GetRow(i int) interface{} {
|
||||
return data.Data[i*data.Dim : (i+1)*data.Dim]
|
||||
if data.GetNullable() && !data.ValidData[i] {
|
||||
return nil
|
||||
}
|
||||
physicalIdx := data.L2PMapping.GetPhysicalOffset(i)
|
||||
return data.Data[physicalIdx*data.Dim : (physicalIdx+1)*data.Dim]
|
||||
}
|
||||
|
||||
func (data *VectorArrayFieldData) GetRow(i int) interface{} {
|
||||
@ -862,42 +1008,83 @@ func (data *GeometryFieldData) AppendRow(row interface{}) error {
|
||||
}
|
||||
|
||||
func (data *BinaryVectorFieldData) AppendRow(row interface{}) error {
|
||||
if data.GetNullable() && row == nil {
|
||||
data.L2PMapping.Build([]bool{false}, len(data.ValidData), 1)
|
||||
data.ValidData = append(data.ValidData, false)
|
||||
return nil
|
||||
}
|
||||
v, ok := row.([]byte)
|
||||
if !ok || len(v) != data.Dim/8 {
|
||||
return merr.WrapErrParameterInvalid("[]byte", row, "Wrong row type")
|
||||
}
|
||||
data.Data = append(data.Data, v...)
|
||||
if data.GetNullable() {
|
||||
data.L2PMapping.Build([]bool{true}, len(data.ValidData), 1)
|
||||
data.ValidData = append(data.ValidData, true)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (data *FloatVectorFieldData) AppendRow(row interface{}) error {
|
||||
if data.GetNullable() && row == nil {
|
||||
data.L2PMapping.Build([]bool{false}, len(data.ValidData), 1)
|
||||
data.ValidData = append(data.ValidData, false)
|
||||
return nil
|
||||
}
|
||||
v, ok := row.([]float32)
|
||||
if !ok || len(v) != data.Dim {
|
||||
return merr.WrapErrParameterInvalid("[]float32", row, "Wrong row type")
|
||||
}
|
||||
data.Data = append(data.Data, v...)
|
||||
if data.GetNullable() {
|
||||
data.L2PMapping.Build([]bool{true}, len(data.ValidData), 1)
|
||||
data.ValidData = append(data.ValidData, true)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (data *Float16VectorFieldData) AppendRow(row interface{}) error {
|
||||
if data.GetNullable() && row == nil {
|
||||
data.L2PMapping.Build([]bool{false}, len(data.ValidData), 1)
|
||||
data.ValidData = append(data.ValidData, false)
|
||||
return nil
|
||||
}
|
||||
v, ok := row.([]byte)
|
||||
if !ok || len(v) != data.Dim*2 {
|
||||
return merr.WrapErrParameterInvalid("[]byte", row, "Wrong row type")
|
||||
}
|
||||
data.Data = append(data.Data, v...)
|
||||
if data.GetNullable() {
|
||||
data.L2PMapping.Build([]bool{true}, len(data.ValidData), 1)
|
||||
data.ValidData = append(data.ValidData, true)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (data *BFloat16VectorFieldData) AppendRow(row interface{}) error {
|
||||
if data.GetNullable() && row == nil {
|
||||
data.L2PMapping.Build([]bool{false}, len(data.ValidData), 1)
|
||||
data.ValidData = append(data.ValidData, false)
|
||||
return nil
|
||||
}
|
||||
v, ok := row.([]byte)
|
||||
if !ok || len(v) != data.Dim*2 {
|
||||
return merr.WrapErrParameterInvalid("[]byte", row, "Wrong row type")
|
||||
}
|
||||
data.Data = append(data.Data, v...)
|
||||
if data.GetNullable() {
|
||||
data.L2PMapping.Build([]bool{true}, len(data.ValidData), 1)
|
||||
data.ValidData = append(data.ValidData, true)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (data *SparseFloatVectorFieldData) AppendRow(row interface{}) error {
|
||||
if data.GetNullable() && row == nil {
|
||||
data.L2PMapping.Build([]bool{false}, len(data.ValidData), 1)
|
||||
data.ValidData = append(data.ValidData, false)
|
||||
return nil
|
||||
}
|
||||
v, ok := row.([]byte)
|
||||
if !ok {
|
||||
return merr.WrapErrParameterInvalid("SparseFloatVectorRowData", row, "Wrong row type")
|
||||
@ -910,15 +1097,28 @@ func (data *SparseFloatVectorFieldData) AppendRow(row interface{}) error {
|
||||
data.Dim = rowDim
|
||||
}
|
||||
data.Contents = append(data.Contents, v)
|
||||
if data.GetNullable() {
|
||||
data.L2PMapping.Build([]bool{true}, len(data.ValidData), 1)
|
||||
data.ValidData = append(data.ValidData, true)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (data *Int8VectorFieldData) AppendRow(row interface{}) error {
|
||||
if data.GetNullable() && row == nil {
|
||||
data.L2PMapping.Build([]bool{false}, len(data.ValidData), 1)
|
||||
data.ValidData = append(data.ValidData, false)
|
||||
return nil
|
||||
}
|
||||
v, ok := row.([]int8)
|
||||
if !ok || len(v) != data.Dim {
|
||||
return merr.WrapErrParameterInvalid("[]int8", row, "Wrong row type")
|
||||
}
|
||||
data.Data = append(data.Data, v...)
|
||||
if data.GetNullable() {
|
||||
data.L2PMapping.Build([]bool{true}, len(data.ValidData), 1)
|
||||
data.ValidData = append(data.ValidData, true)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -1431,15 +1631,15 @@ func (data *GeometryFieldData) AppendValidDataRows(rows interface{}) error {
|
||||
|
||||
// AppendValidDataRows appends FLATTEN vectors to field data.
|
||||
func (data *BinaryVectorFieldData) AppendValidDataRows(rows interface{}) error {
|
||||
if rows != nil {
|
||||
v, ok := rows.([]bool)
|
||||
if !ok {
|
||||
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
|
||||
}
|
||||
if len(v) != 0 {
|
||||
return merr.WrapErrParameterInvalidMsg("not support Nullable in vector")
|
||||
}
|
||||
if rows == nil {
|
||||
return nil
|
||||
}
|
||||
v, ok := rows.([]bool)
|
||||
if !ok {
|
||||
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
|
||||
}
|
||||
data.L2PMapping.Build(v, len(data.ValidData), len(v))
|
||||
data.ValidData = append(data.ValidData, v...)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -1458,69 +1658,69 @@ func (data *VectorArrayFieldData) AppendValidDataRows(rows interface{}) error {
|
||||
|
||||
// AppendValidDataRows appends FLATTEN vectors to field data.
|
||||
func (data *FloatVectorFieldData) AppendValidDataRows(rows interface{}) error {
|
||||
if rows != nil {
|
||||
v, ok := rows.([]bool)
|
||||
if !ok {
|
||||
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
|
||||
}
|
||||
if len(v) != 0 {
|
||||
return merr.WrapErrParameterInvalidMsg("not support Nullable in vector")
|
||||
}
|
||||
if rows == nil {
|
||||
return nil
|
||||
}
|
||||
v, ok := rows.([]bool)
|
||||
if !ok {
|
||||
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
|
||||
}
|
||||
data.L2PMapping.Build(v, len(data.ValidData), len(v))
|
||||
data.ValidData = append(data.ValidData, v...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AppendValidDataRows appends FLATTEN vectors to field data.
|
||||
func (data *Float16VectorFieldData) AppendValidDataRows(rows interface{}) error {
|
||||
if rows != nil {
|
||||
v, ok := rows.([]bool)
|
||||
if !ok {
|
||||
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
|
||||
}
|
||||
if len(v) != 0 {
|
||||
return merr.WrapErrParameterInvalidMsg("not support Nullable in vector")
|
||||
}
|
||||
if rows == nil {
|
||||
return nil
|
||||
}
|
||||
v, ok := rows.([]bool)
|
||||
if !ok {
|
||||
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
|
||||
}
|
||||
data.L2PMapping.Build(v, len(data.ValidData), len(v))
|
||||
data.ValidData = append(data.ValidData, v...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AppendValidDataRows appends FLATTEN vectors to field data.
|
||||
func (data *BFloat16VectorFieldData) AppendValidDataRows(rows interface{}) error {
|
||||
if rows != nil {
|
||||
v, ok := rows.([]bool)
|
||||
if !ok {
|
||||
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
|
||||
}
|
||||
if len(v) != 0 {
|
||||
return merr.WrapErrParameterInvalidMsg("not support Nullable in vector")
|
||||
}
|
||||
if rows == nil {
|
||||
return nil
|
||||
}
|
||||
v, ok := rows.([]bool)
|
||||
if !ok {
|
||||
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
|
||||
}
|
||||
data.L2PMapping.Build(v, len(data.ValidData), len(v))
|
||||
data.ValidData = append(data.ValidData, v...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (data *SparseFloatVectorFieldData) AppendValidDataRows(rows interface{}) error {
|
||||
if rows != nil {
|
||||
v, ok := rows.([]bool)
|
||||
if !ok {
|
||||
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
|
||||
}
|
||||
if len(v) != 0 {
|
||||
return merr.WrapErrParameterInvalidMsg("not support Nullable in vector")
|
||||
}
|
||||
if rows == nil {
|
||||
return nil
|
||||
}
|
||||
v, ok := rows.([]bool)
|
||||
if !ok {
|
||||
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
|
||||
}
|
||||
data.L2PMapping.Build(v, len(data.ValidData), len(v))
|
||||
data.ValidData = append(data.ValidData, v...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (data *Int8VectorFieldData) AppendValidDataRows(rows interface{}) error {
|
||||
if rows != nil {
|
||||
v, ok := rows.([]bool)
|
||||
if !ok {
|
||||
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
|
||||
}
|
||||
if len(v) != 0 {
|
||||
return merr.WrapErrParameterInvalidMsg("not support Nullable in vector")
|
||||
}
|
||||
if rows == nil {
|
||||
return nil
|
||||
}
|
||||
v, ok := rows.([]bool)
|
||||
if !ok {
|
||||
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
|
||||
}
|
||||
data.L2PMapping.Build(v, len(data.ValidData), len(v))
|
||||
data.ValidData = append(data.ValidData, v...)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -1557,17 +1757,32 @@ func (data *TimestamptzFieldData) GetMemorySize() int {
|
||||
return binary.Size(data.Data) + binary.Size(data.ValidData) + binary.Size(data.Nullable)
|
||||
}
|
||||
|
||||
func (data *BinaryVectorFieldData) GetMemorySize() int { return binary.Size(data.Data) + 4 }
|
||||
func (data *FloatVectorFieldData) GetMemorySize() int { return binary.Size(data.Data) + 4 }
|
||||
func (data *Float16VectorFieldData) GetMemorySize() int { return binary.Size(data.Data) + 4 }
|
||||
func (data *BFloat16VectorFieldData) GetMemorySize() int { return binary.Size(data.Data) + 4 }
|
||||
func (data *BinaryVectorFieldData) GetMemorySize() int {
|
||||
// Data + ValidData + Dim(4) + Nullable(1) + L2PMapping
|
||||
return binary.Size(data.Data) + binary.Size(data.ValidData) + 4 + 1 + data.L2PMapping.GetMemorySize()
|
||||
}
|
||||
|
||||
func (data *FloatVectorFieldData) GetMemorySize() int {
|
||||
return binary.Size(data.Data) + binary.Size(data.ValidData) + 4 + 1 + data.L2PMapping.GetMemorySize()
|
||||
}
|
||||
|
||||
func (data *Float16VectorFieldData) GetMemorySize() int {
|
||||
return binary.Size(data.Data) + binary.Size(data.ValidData) + 4 + 1 + data.L2PMapping.GetMemorySize()
|
||||
}
|
||||
|
||||
func (data *BFloat16VectorFieldData) GetMemorySize() int {
|
||||
return binary.Size(data.Data) + binary.Size(data.ValidData) + 4 + 1 + data.L2PMapping.GetMemorySize()
|
||||
}
|
||||
|
||||
func (data *SparseFloatVectorFieldData) GetMemorySize() int {
|
||||
// TODO(SPARSE): should this be the memory size of serialzied size?
|
||||
return proto.Size(&data.SparseFloatArray)
|
||||
// SparseFloatArray + ValidData + Nullable(1) + L2PMapping
|
||||
return proto.Size(&data.SparseFloatArray) + binary.Size(data.ValidData) + 1 + data.L2PMapping.GetMemorySize()
|
||||
}
|
||||
|
||||
func (data *Int8VectorFieldData) GetMemorySize() int { return binary.Size(data.Data) + 4 }
|
||||
func (data *Int8VectorFieldData) GetMemorySize() int {
|
||||
return binary.Size(data.Data) + binary.Size(data.ValidData) + 4 + 1 + data.L2PMapping.GetMemorySize()
|
||||
}
|
||||
|
||||
func GetVectorSize(vector *schemapb.VectorField, vectorType schemapb.DataType) int {
|
||||
size := 0
|
||||
@ -1698,22 +1913,51 @@ func (data *GeometryFieldData) GetMemorySize() int {
|
||||
return size + binary.Size(data.ValidData) + binary.Size(data.Nullable)
|
||||
}
|
||||
|
||||
func (data *BoolFieldData) GetRowSize(i int) int { return 1 }
|
||||
func (data *Int8FieldData) GetRowSize(i int) int { return 1 }
|
||||
func (data *Int16FieldData) GetRowSize(i int) int { return 2 }
|
||||
func (data *Int32FieldData) GetRowSize(i int) int { return 4 }
|
||||
func (data *Int64FieldData) GetRowSize(i int) int { return 8 }
|
||||
func (data *FloatFieldData) GetRowSize(i int) int { return 4 }
|
||||
func (data *DoubleFieldData) GetRowSize(i int) int { return 8 }
|
||||
func (data *TimestamptzFieldData) GetRowSize(i int) int { return 8 }
|
||||
func (data *BinaryVectorFieldData) GetRowSize(i int) int { return data.Dim / 8 }
|
||||
func (data *FloatVectorFieldData) GetRowSize(i int) int { return data.Dim * 4 }
|
||||
func (data *Float16VectorFieldData) GetRowSize(i int) int { return data.Dim * 2 }
|
||||
func (data *BFloat16VectorFieldData) GetRowSize(i int) int { return data.Dim * 2 }
|
||||
func (data *Int8VectorFieldData) GetRowSize(i int) int { return data.Dim }
|
||||
func (data *StringFieldData) GetRowSize(i int) int { return len(data.Data[i]) + 16 }
|
||||
func (data *JSONFieldData) GetRowSize(i int) int { return len(data.Data[i]) + 16 }
|
||||
func (data *GeometryFieldData) GetRowSize(i int) int { return len(data.Data[i]) + 16 }
|
||||
func (data *BoolFieldData) GetRowSize(i int) int { return 1 }
|
||||
func (data *Int8FieldData) GetRowSize(i int) int { return 1 }
|
||||
func (data *Int16FieldData) GetRowSize(i int) int { return 2 }
|
||||
func (data *Int32FieldData) GetRowSize(i int) int { return 4 }
|
||||
func (data *Int64FieldData) GetRowSize(i int) int { return 8 }
|
||||
func (data *FloatFieldData) GetRowSize(i int) int { return 4 }
|
||||
func (data *DoubleFieldData) GetRowSize(i int) int { return 8 }
|
||||
func (data *TimestamptzFieldData) GetRowSize(i int) int { return 8 }
|
||||
func (data *BinaryVectorFieldData) GetRowSize(i int) int {
|
||||
if data.GetNullable() && !data.ValidData[i] {
|
||||
return 0
|
||||
}
|
||||
return data.Dim / 8
|
||||
}
|
||||
|
||||
func (data *FloatVectorFieldData) GetRowSize(i int) int {
|
||||
if data.GetNullable() && !data.ValidData[i] {
|
||||
return 0
|
||||
}
|
||||
return data.Dim * 4
|
||||
}
|
||||
|
||||
func (data *Float16VectorFieldData) GetRowSize(i int) int {
|
||||
if data.GetNullable() && !data.ValidData[i] {
|
||||
return 0
|
||||
}
|
||||
return data.Dim * 2
|
||||
}
|
||||
|
||||
func (data *BFloat16VectorFieldData) GetRowSize(i int) int {
|
||||
if data.GetNullable() && !data.ValidData[i] {
|
||||
return 0
|
||||
}
|
||||
return data.Dim * 2
|
||||
}
|
||||
|
||||
func (data *Int8VectorFieldData) GetRowSize(i int) int {
|
||||
if data.GetNullable() && !data.ValidData[i] {
|
||||
return 0
|
||||
}
|
||||
return data.Dim
|
||||
}
|
||||
func (data *StringFieldData) GetRowSize(i int) int { return len(data.Data[i]) + 16 }
|
||||
func (data *JSONFieldData) GetRowSize(i int) int { return len(data.Data[i]) + 16 }
|
||||
func (data *GeometryFieldData) GetRowSize(i int) int { return len(data.Data[i]) + 16 }
|
||||
func (data *ArrayFieldData) GetRowSize(i int) int {
|
||||
switch data.ElementType {
|
||||
case schemapb.DataType_Bool:
|
||||
@ -1737,7 +1981,11 @@ func (data *ArrayFieldData) GetRowSize(i int) int {
|
||||
}
|
||||
|
||||
func (data *SparseFloatVectorFieldData) GetRowSize(i int) int {
|
||||
return len(data.Contents[i])
|
||||
if data.GetNullable() && !data.ValidData[i] {
|
||||
return 0
|
||||
}
|
||||
physicalIdx := data.L2PMapping.GetPhysicalOffset(i)
|
||||
return len(data.Contents[physicalIdx])
|
||||
}
|
||||
|
||||
func (data *VectorArrayFieldData) GetRowSize(i int) int {
|
||||
@ -1777,27 +2025,27 @@ func (data *TimestamptzFieldData) GetNullable() bool {
|
||||
}
|
||||
|
||||
func (data *BFloat16VectorFieldData) GetNullable() bool {
|
||||
return false
|
||||
return data.Nullable
|
||||
}
|
||||
|
||||
func (data *BinaryVectorFieldData) GetNullable() bool {
|
||||
return false
|
||||
return data.Nullable
|
||||
}
|
||||
|
||||
func (data *FloatVectorFieldData) GetNullable() bool {
|
||||
return false
|
||||
return data.Nullable
|
||||
}
|
||||
|
||||
func (data *SparseFloatVectorFieldData) GetNullable() bool {
|
||||
return false
|
||||
return data.Nullable
|
||||
}
|
||||
|
||||
func (data *Float16VectorFieldData) GetNullable() bool {
|
||||
return false
|
||||
return data.Nullable
|
||||
}
|
||||
|
||||
func (data *Int8VectorFieldData) GetNullable() bool {
|
||||
return false
|
||||
return data.Nullable
|
||||
}
|
||||
|
||||
func (data *StringFieldData) GetNullable() bool {
|
||||
@ -1819,3 +2067,23 @@ func (data *VectorArrayFieldData) GetNullable() bool {
|
||||
func (data *GeometryFieldData) GetNullable() bool {
|
||||
return data.Nullable
|
||||
}
|
||||
|
||||
func (data *BoolFieldData) GetValidData() []bool { return data.ValidData }
|
||||
func (data *Int8FieldData) GetValidData() []bool { return data.ValidData }
|
||||
func (data *Int16FieldData) GetValidData() []bool { return data.ValidData }
|
||||
func (data *Int32FieldData) GetValidData() []bool { return data.ValidData }
|
||||
func (data *Int64FieldData) GetValidData() []bool { return data.ValidData }
|
||||
func (data *FloatFieldData) GetValidData() []bool { return data.ValidData }
|
||||
func (data *DoubleFieldData) GetValidData() []bool { return data.ValidData }
|
||||
func (data *TimestamptzFieldData) GetValidData() []bool { return data.ValidData }
|
||||
func (data *StringFieldData) GetValidData() []bool { return data.ValidData }
|
||||
func (data *ArrayFieldData) GetValidData() []bool { return data.ValidData }
|
||||
func (data *JSONFieldData) GetValidData() []bool { return data.ValidData }
|
||||
func (data *GeometryFieldData) GetValidData() []bool { return data.ValidData }
|
||||
func (data *BinaryVectorFieldData) GetValidData() []bool { return data.ValidData }
|
||||
func (data *FloatVectorFieldData) GetValidData() []bool { return data.ValidData }
|
||||
func (data *Float16VectorFieldData) GetValidData() []bool { return data.ValidData }
|
||||
func (data *BFloat16VectorFieldData) GetValidData() []bool { return data.ValidData }
|
||||
func (data *SparseFloatVectorFieldData) GetValidData() []bool { return data.ValidData }
|
||||
func (data *Int8VectorFieldData) GetValidData() []bool { return data.ValidData }
|
||||
func (data *VectorArrayFieldData) GetValidData() []bool { return nil }
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
@ -45,17 +46,31 @@ func (s *InsertDataSuite) TestInsertData() {
|
||||
tests := []struct {
|
||||
description string
|
||||
dataType schemapb.DataType
|
||||
typeParams []*commonpb.KeyValuePair
|
||||
nullable bool
|
||||
}{
|
||||
{"nullable bool field", schemapb.DataType_Bool},
|
||||
{"nullable int8 field", schemapb.DataType_Int8},
|
||||
{"nullable int16 field", schemapb.DataType_Int16},
|
||||
{"nullable int32 field", schemapb.DataType_Int32},
|
||||
{"nullable int64 field", schemapb.DataType_Int64},
|
||||
{"nullable float field", schemapb.DataType_Float},
|
||||
{"nullable double field", schemapb.DataType_Double},
|
||||
{"nullable json field", schemapb.DataType_JSON},
|
||||
{"nullable array field", schemapb.DataType_Array},
|
||||
{"nullable string/varchar field", schemapb.DataType_String},
|
||||
{"nullable bool field", schemapb.DataType_Bool, nil, true},
|
||||
{"nullable int8 field", schemapb.DataType_Int8, nil, true},
|
||||
{"nullable int16 field", schemapb.DataType_Int16, nil, true},
|
||||
{"nullable int32 field", schemapb.DataType_Int32, nil, true},
|
||||
{"nullable int64 field", schemapb.DataType_Int64, nil, true},
|
||||
{"nullable float field", schemapb.DataType_Float, nil, true},
|
||||
{"nullable double field", schemapb.DataType_Double, nil, true},
|
||||
{"nullable json field", schemapb.DataType_JSON, nil, true},
|
||||
{"nullable array field", schemapb.DataType_Array, nil, true},
|
||||
{"nullable string/varchar field", schemapb.DataType_String, nil, true},
|
||||
{"nullable binary vector field", schemapb.DataType_BinaryVector, []*commonpb.KeyValuePair{{Key: "dim", Value: "8"}}, true},
|
||||
{"nullable float vector field", schemapb.DataType_FloatVector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, true},
|
||||
{"nullable float16 vector field", schemapb.DataType_Float16Vector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, true},
|
||||
{"nullable bfloat16 vector field", schemapb.DataType_BFloat16Vector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, true},
|
||||
{"nullable sparse float vector field", schemapb.DataType_SparseFloatVector, nil, true},
|
||||
{"nullable int8 vector field", schemapb.DataType_Int8Vector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, true},
|
||||
{"non-nullable binary vector field", schemapb.DataType_BinaryVector, []*commonpb.KeyValuePair{{Key: "dim", Value: "8"}}, false},
|
||||
{"non-nullable float vector field", schemapb.DataType_FloatVector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, false},
|
||||
{"non-nullable float16 vector field", schemapb.DataType_Float16Vector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, false},
|
||||
{"non-nullable bfloat16 vector field", schemapb.DataType_BFloat16Vector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, false},
|
||||
{"non-nullable sparse float vector field", schemapb.DataType_SparseFloatVector, nil, false},
|
||||
{"non-nullable int8 vector field", schemapb.DataType_Int8Vector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
@ -63,8 +78,9 @@ func (s *InsertDataSuite) TestInsertData() {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
DataType: test.dataType,
|
||||
Nullable: true,
|
||||
DataType: test.dataType,
|
||||
Nullable: test.nullable,
|
||||
TypeParams: test.typeParams,
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -115,15 +131,15 @@ func (s *InsertDataSuite) TestInsertData() {
|
||||
s.Run("init by New", func() {
|
||||
s.True(s.iDataEmpty.IsEmpty())
|
||||
s.Equal(0, s.iDataEmpty.GetRowNum())
|
||||
s.Equal(33, s.iDataEmpty.GetMemorySize())
|
||||
s.Equal(161, s.iDataEmpty.GetMemorySize())
|
||||
|
||||
s.False(s.iDataOneRow.IsEmpty())
|
||||
s.Equal(1, s.iDataOneRow.GetRowNum())
|
||||
s.Equal(240, s.iDataOneRow.GetMemorySize())
|
||||
s.Equal(535, s.iDataOneRow.GetMemorySize())
|
||||
|
||||
s.False(s.iDataTwoRows.IsEmpty())
|
||||
s.Equal(2, s.iDataTwoRows.GetRowNum())
|
||||
s.Equal(433, s.iDataTwoRows.GetMemorySize())
|
||||
s.Equal(734, s.iDataTwoRows.GetMemorySize())
|
||||
|
||||
for _, field := range s.iDataTwoRows.Data {
|
||||
s.Equal(2, field.RowNum())
|
||||
@ -147,12 +163,13 @@ func (s *InsertDataSuite) TestMemorySize() {
|
||||
s.Equal(s.iDataEmpty.Data[DoubleField].GetMemorySize(), 1)
|
||||
s.Equal(s.iDataEmpty.Data[StringField].GetMemorySize(), 1)
|
||||
s.Equal(s.iDataEmpty.Data[ArrayField].GetMemorySize(), 1)
|
||||
s.Equal(s.iDataEmpty.Data[BinaryVectorField].GetMemorySize(), 4)
|
||||
s.Equal(s.iDataEmpty.Data[FloatVectorField].GetMemorySize(), 4)
|
||||
s.Equal(s.iDataEmpty.Data[Float16VectorField].GetMemorySize(), 4)
|
||||
s.Equal(s.iDataEmpty.Data[BFloat16VectorField].GetMemorySize(), 4)
|
||||
s.Equal(s.iDataEmpty.Data[SparseFloatVectorField].GetMemorySize(), 0)
|
||||
s.Equal(s.iDataEmpty.Data[Int8VectorField].GetMemorySize(), 4)
|
||||
// +9 bytes: Nullable(1) + L2PMapping.GetMemorySize()(8)
|
||||
s.Equal(s.iDataEmpty.Data[BinaryVectorField].GetMemorySize(), 4+9)
|
||||
s.Equal(s.iDataEmpty.Data[FloatVectorField].GetMemorySize(), 4+9)
|
||||
s.Equal(s.iDataEmpty.Data[Float16VectorField].GetMemorySize(), 4+9)
|
||||
s.Equal(s.iDataEmpty.Data[BFloat16VectorField].GetMemorySize(), 4+9)
|
||||
s.Equal(s.iDataEmpty.Data[SparseFloatVectorField].GetMemorySize(), 0+9)
|
||||
s.Equal(s.iDataEmpty.Data[Int8VectorField].GetMemorySize(), 4+9)
|
||||
s.Equal(s.iDataEmpty.Data[StructSubInt32Field].GetMemorySize(), 1)
|
||||
s.Equal(s.iDataEmpty.Data[StructSubFloatVectorField].GetMemorySize(), 0)
|
||||
|
||||
@ -168,12 +185,13 @@ func (s *InsertDataSuite) TestMemorySize() {
|
||||
s.Equal(s.iDataOneRow.Data[StringField].GetMemorySize(), 20)
|
||||
s.Equal(s.iDataOneRow.Data[JSONField].GetMemorySize(), len([]byte(`{"batch":1}`))+16+1)
|
||||
s.Equal(s.iDataOneRow.Data[ArrayField].GetMemorySize(), 3*4+1)
|
||||
s.Equal(s.iDataOneRow.Data[BinaryVectorField].GetMemorySize(), 5)
|
||||
s.Equal(s.iDataOneRow.Data[FloatVectorField].GetMemorySize(), 20)
|
||||
s.Equal(s.iDataOneRow.Data[Float16VectorField].GetMemorySize(), 12)
|
||||
s.Equal(s.iDataOneRow.Data[BFloat16VectorField].GetMemorySize(), 12)
|
||||
s.Equal(s.iDataOneRow.Data[SparseFloatVectorField].GetMemorySize(), 28)
|
||||
s.Equal(s.iDataOneRow.Data[Int8VectorField].GetMemorySize(), 8)
|
||||
// +9 bytes: Nullable(1) + L2PMapping.GetMemorySize()(8)
|
||||
s.Equal(s.iDataOneRow.Data[BinaryVectorField].GetMemorySize(), 5+9)
|
||||
s.Equal(s.iDataOneRow.Data[FloatVectorField].GetMemorySize(), 20+9)
|
||||
s.Equal(s.iDataOneRow.Data[Float16VectorField].GetMemorySize(), 12+9)
|
||||
s.Equal(s.iDataOneRow.Data[BFloat16VectorField].GetMemorySize(), 12+9)
|
||||
s.Equal(s.iDataOneRow.Data[SparseFloatVectorField].GetMemorySize(), 28+9)
|
||||
s.Equal(s.iDataOneRow.Data[Int8VectorField].GetMemorySize(), 8+9)
|
||||
s.Equal(s.iDataOneRow.Data[StructSubInt32Field].GetMemorySize(), 3*4+1)
|
||||
s.Equal(s.iDataOneRow.Data[StructSubFloatVectorField].GetMemorySize(), 3*4*2+4)
|
||||
|
||||
@ -188,12 +206,13 @@ func (s *InsertDataSuite) TestMemorySize() {
|
||||
s.Equal(s.iDataTwoRows.Data[DoubleField].GetMemorySize(), 17)
|
||||
s.Equal(s.iDataTwoRows.Data[StringField].GetMemorySize(), 39)
|
||||
s.Equal(s.iDataTwoRows.Data[ArrayField].GetMemorySize(), 25)
|
||||
s.Equal(s.iDataTwoRows.Data[BinaryVectorField].GetMemorySize(), 6)
|
||||
s.Equal(s.iDataTwoRows.Data[FloatVectorField].GetMemorySize(), 36)
|
||||
s.Equal(s.iDataTwoRows.Data[Float16VectorField].GetMemorySize(), 20)
|
||||
s.Equal(s.iDataTwoRows.Data[BFloat16VectorField].GetMemorySize(), 20)
|
||||
s.Equal(s.iDataTwoRows.Data[SparseFloatVectorField].GetMemorySize(), 54)
|
||||
s.Equal(s.iDataTwoRows.Data[Int8VectorField].GetMemorySize(), 12)
|
||||
// +9 bytes: Nullable(1) + L2PMapping.GetMemorySize()(8)
|
||||
s.Equal(s.iDataTwoRows.Data[BinaryVectorField].GetMemorySize(), 6+9)
|
||||
s.Equal(s.iDataTwoRows.Data[FloatVectorField].GetMemorySize(), 36+9)
|
||||
s.Equal(s.iDataTwoRows.Data[Float16VectorField].GetMemorySize(), 20+9)
|
||||
s.Equal(s.iDataTwoRows.Data[BFloat16VectorField].GetMemorySize(), 20+9)
|
||||
s.Equal(s.iDataTwoRows.Data[SparseFloatVectorField].GetMemorySize(), 54+9)
|
||||
s.Equal(s.iDataTwoRows.Data[Int8VectorField].GetMemorySize(), 12+9)
|
||||
s.Equal(s.iDataTwoRows.Data[StructSubInt32Field].GetMemorySize(), 3*4+2*4+1)
|
||||
s.Equal(s.iDataTwoRows.Data[StructSubFloatVectorField].GetMemorySize(), 3*4*2+4+2*4*2+4)
|
||||
}
|
||||
@ -252,25 +271,31 @@ func (s *InsertDataSuite) SetupTest() {
|
||||
s.Require().NoError(err)
|
||||
s.True(s.iDataEmpty.IsEmpty())
|
||||
s.Equal(0, s.iDataEmpty.GetRowNum())
|
||||
s.Equal(33, s.iDataEmpty.GetMemorySize())
|
||||
s.Equal(161, s.iDataEmpty.GetMemorySize())
|
||||
|
||||
row1 := map[FieldID]interface{}{
|
||||
RowIDField: int64(3),
|
||||
TimestampField: int64(3),
|
||||
BoolField: true,
|
||||
Int8Field: int8(3),
|
||||
Int16Field: int16(3),
|
||||
Int32Field: int32(3),
|
||||
Int64Field: int64(3),
|
||||
FloatField: float32(3),
|
||||
DoubleField: float64(3),
|
||||
StringField: "str",
|
||||
BinaryVectorField: []byte{0},
|
||||
FloatVectorField: []float32{4, 5, 6, 7},
|
||||
Float16VectorField: []byte{0, 0, 0, 0, 255, 255, 255, 255},
|
||||
BFloat16VectorField: []byte{0, 0, 0, 0, 255, 255, 255, 255},
|
||||
SparseFloatVectorField: typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{4, 5, 6}),
|
||||
Int8VectorField: []int8{-4, -5, 6, 7},
|
||||
RowIDField: int64(3),
|
||||
TimestampField: int64(3),
|
||||
BoolField: true,
|
||||
Int8Field: int8(3),
|
||||
Int16Field: int16(3),
|
||||
Int32Field: int32(3),
|
||||
Int64Field: int64(3),
|
||||
FloatField: float32(3),
|
||||
DoubleField: float64(3),
|
||||
StringField: "str",
|
||||
BinaryVectorField: []byte{0},
|
||||
FloatVectorField: []float32{4, 5, 6, 7},
|
||||
Float16VectorField: []byte{0, 0, 0, 0, 255, 255, 255, 255},
|
||||
BFloat16VectorField: []byte{0, 0, 0, 0, 255, 255, 255, 255},
|
||||
SparseFloatVectorField: typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{4, 5, 6}),
|
||||
Int8VectorField: []int8{-4, -5, 6, 7},
|
||||
NullableFloatVectorField: []float32{1.0, 2.0, 3.0, 4.0},
|
||||
NullableBinaryVectorField: []byte{1},
|
||||
NullableFloat16VectorField: []byte{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
NullableBFloat16VectorField: []byte{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
NullableInt8VectorField: []int8{1, 2, 3, 4},
|
||||
NullableSparseFloatVectorField: typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{4, 5, 6}),
|
||||
ArrayField: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{Data: []int32{1, 2, 3}},
|
||||
@ -300,22 +325,28 @@ func (s *InsertDataSuite) SetupTest() {
|
||||
}
|
||||
|
||||
row2 := map[FieldID]interface{}{
|
||||
RowIDField: int64(1),
|
||||
TimestampField: int64(1),
|
||||
BoolField: false,
|
||||
Int8Field: int8(1),
|
||||
Int16Field: int16(1),
|
||||
Int32Field: int32(1),
|
||||
Int64Field: int64(1),
|
||||
FloatField: float32(1),
|
||||
DoubleField: float64(1),
|
||||
StringField: string("str"),
|
||||
BinaryVectorField: []byte{0},
|
||||
FloatVectorField: []float32{4, 5, 6, 7},
|
||||
Float16VectorField: []byte{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
BFloat16VectorField: []byte{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
SparseFloatVectorField: typeutil.CreateSparseFloatRow([]uint32{2, 3, 4}, []float32{4, 5, 6}),
|
||||
Int8VectorField: []int8{-128, -5, 6, 127},
|
||||
RowIDField: int64(1),
|
||||
TimestampField: int64(1),
|
||||
BoolField: false,
|
||||
Int8Field: int8(1),
|
||||
Int16Field: int16(1),
|
||||
Int32Field: int32(1),
|
||||
Int64Field: int64(1),
|
||||
FloatField: float32(1),
|
||||
DoubleField: float64(1),
|
||||
StringField: string("str"),
|
||||
BinaryVectorField: []byte{0},
|
||||
FloatVectorField: []float32{4, 5, 6, 7},
|
||||
Float16VectorField: []byte{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
BFloat16VectorField: []byte{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
SparseFloatVectorField: typeutil.CreateSparseFloatRow([]uint32{2, 3, 4}, []float32{4, 5, 6}),
|
||||
Int8VectorField: []int8{-128, -5, 6, 127},
|
||||
NullableFloatVectorField: nil,
|
||||
NullableBinaryVectorField: nil,
|
||||
NullableFloat16VectorField: nil,
|
||||
NullableBFloat16VectorField: nil,
|
||||
NullableInt8VectorField: nil,
|
||||
NullableSparseFloatVectorField: nil,
|
||||
ArrayField: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{Data: []int32{1, 2, 3}},
|
||||
|
||||
@ -40,12 +40,12 @@ type PayloadWriterInterface interface {
|
||||
AddOneArrayToPayload(*schemapb.ScalarField, bool) error
|
||||
AddOneJSONToPayload([]byte, bool) error
|
||||
AddOneGeometryToPayload(msg []byte, isValid bool) error
|
||||
AddBinaryVectorToPayload([]byte, int) error
|
||||
AddFloatVectorToPayload([]float32, int) error
|
||||
AddFloat16VectorToPayload([]byte, int) error
|
||||
AddBFloat16VectorToPayload([]byte, int) error
|
||||
AddBinaryVectorToPayload(data []byte, dim int, validData []bool) error
|
||||
AddFloatVectorToPayload(data []float32, dim int, validData []bool) error
|
||||
AddFloat16VectorToPayload(data []byte, dim int, validData []bool) error
|
||||
AddBFloat16VectorToPayload(data []byte, dim int, validData []bool) error
|
||||
AddSparseFloatVectorToPayload(*SparseFloatVectorFieldData) error
|
||||
AddInt8VectorToPayload([]int8, int) error
|
||||
AddInt8VectorToPayload(data []int8, dim int, validData []bool) error
|
||||
AddVectorArrayFieldDataToPayload(*VectorArrayFieldData) error
|
||||
FinishPayloadWriter() error
|
||||
GetPayloadBufferFromWriter() ([]byte, error)
|
||||
@ -72,12 +72,12 @@ type PayloadReaderInterface interface {
|
||||
GetVectorArrayFromPayload() ([]*schemapb.VectorField, error)
|
||||
GetJSONFromPayload() ([][]byte, []bool, error)
|
||||
GetGeometryFromPayload() ([][]byte, []bool, error)
|
||||
GetBinaryVectorFromPayload() ([]byte, int, error)
|
||||
GetFloat16VectorFromPayload() ([]byte, int, error)
|
||||
GetBFloat16VectorFromPayload() ([]byte, int, error)
|
||||
GetFloatVectorFromPayload() ([]float32, int, error)
|
||||
GetSparseFloatVectorFromPayload() (*SparseFloatVectorFieldData, int, error)
|
||||
GetInt8VectorFromPayload() ([]int8, int, error)
|
||||
GetBinaryVectorFromPayload() ([]byte, int, []bool, int, error)
|
||||
GetFloat16VectorFromPayload() ([]byte, int, []bool, int, error)
|
||||
GetBFloat16VectorFromPayload() ([]byte, int, []bool, int, error)
|
||||
GetFloatVectorFromPayload() ([]float32, int, []bool, int, error)
|
||||
GetSparseFloatVectorFromPayload() (*SparseFloatVectorFieldData, int, []bool, error)
|
||||
GetInt8VectorFromPayload() ([]int8, int, []bool, int, error)
|
||||
GetPayloadLengthFromReader() (int, error)
|
||||
|
||||
GetByteArrayDataSet() (*DataSet[parquet.ByteArray, *file.ByteArrayColumnChunkReader], error)
|
||||
|
||||
@ -149,23 +149,23 @@ func (r *PayloadReader) GetDataFromPayload() (interface{}, []bool, int, error) {
|
||||
val, validData, err := r.GetTimestamptzFromPayload()
|
||||
return val, validData, 0, err
|
||||
case schemapb.DataType_BinaryVector:
|
||||
val, dim, err := r.GetBinaryVectorFromPayload()
|
||||
return val, nil, dim, err
|
||||
val, dim, validData, _, err := r.GetBinaryVectorFromPayload()
|
||||
return val, validData, dim, err
|
||||
case schemapb.DataType_FloatVector:
|
||||
val, dim, err := r.GetFloatVectorFromPayload()
|
||||
return val, nil, dim, err
|
||||
val, dim, validData, _, err := r.GetFloatVectorFromPayload()
|
||||
return val, validData, dim, err
|
||||
case schemapb.DataType_Float16Vector:
|
||||
val, dim, err := r.GetFloat16VectorFromPayload()
|
||||
return val, nil, dim, err
|
||||
val, dim, validData, _, err := r.GetFloat16VectorFromPayload()
|
||||
return val, validData, dim, err
|
||||
case schemapb.DataType_BFloat16Vector:
|
||||
val, dim, err := r.GetBFloat16VectorFromPayload()
|
||||
return val, nil, dim, err
|
||||
val, dim, validData, _, err := r.GetBFloat16VectorFromPayload()
|
||||
return val, validData, dim, err
|
||||
case schemapb.DataType_SparseFloatVector:
|
||||
val, dim, err := r.GetSparseFloatVectorFromPayload()
|
||||
return val, nil, dim, err
|
||||
val, dim, validData, err := r.GetSparseFloatVectorFromPayload()
|
||||
return val, validData, dim, err
|
||||
case schemapb.DataType_Int8Vector:
|
||||
val, dim, err := r.GetInt8VectorFromPayload()
|
||||
return val, nil, dim, err
|
||||
val, dim, validData, _, err := r.GetInt8VectorFromPayload()
|
||||
return val, validData, dim, err
|
||||
case schemapb.DataType_String, schemapb.DataType_VarChar:
|
||||
val, validData, err := r.GetStringFromPayload()
|
||||
return val, validData, 0, err
|
||||
@ -681,96 +681,434 @@ func readByteAndConvert[T any](r *PayloadReader, convert func(parquet.ByteArray)
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// GetBinaryVectorFromPayload returns vector, dimension, error
|
||||
func (r *PayloadReader) GetBinaryVectorFromPayload() ([]byte, int, error) {
|
||||
// GetBinaryVectorFromPayload returns vector, dimension, validData, numRows, error
|
||||
func (r *PayloadReader) GetBinaryVectorFromPayload() ([]byte, int, []bool, int, error) {
|
||||
if r.colType != schemapb.DataType_BinaryVector {
|
||||
return nil, -1, fmt.Errorf("failed to get binary vector from datatype %v", r.colType.String())
|
||||
return nil, -1, nil, 0, fmt.Errorf("failed to get binary vector from datatype %v", r.colType.String())
|
||||
}
|
||||
|
||||
if r.nullable {
|
||||
fileReader, err := pqarrow.NewFileReader(r.reader, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
|
||||
arrowSchema, err := fileReader.Schema()
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
|
||||
if arrowSchema.NumFields() != 1 {
|
||||
return nil, -1, nil, 0, fmt.Errorf("expected 1 field, got %d", arrowSchema.NumFields())
|
||||
}
|
||||
|
||||
field := arrowSchema.Field(0)
|
||||
var dim int
|
||||
|
||||
if field.Type.ID() == arrow.BINARY {
|
||||
if !field.HasMetadata() {
|
||||
return nil, -1, nil, 0, fmt.Errorf("nullable binary vector field is missing metadata")
|
||||
}
|
||||
metadata := field.Metadata
|
||||
dimStr, ok := metadata.GetValue("dim")
|
||||
if !ok {
|
||||
return nil, -1, nil, 0, fmt.Errorf("nullable binary vector metadata missing required 'dim' field")
|
||||
}
|
||||
var err error
|
||||
dim, err = strconv.Atoi(dimStr)
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, fmt.Errorf("invalid dim value in metadata: %v", err)
|
||||
}
|
||||
dim = dim / 8
|
||||
} else {
|
||||
col, err := r.reader.RowGroup(0).Column(0)
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
dim = col.Descriptor().TypeLength()
|
||||
}
|
||||
|
||||
table, err := fileReader.ReadTable(context.Background())
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
defer table.Release()
|
||||
|
||||
if table.NumCols() != 1 {
|
||||
return nil, -1, nil, 0, fmt.Errorf("expected 1 column, got %d", table.NumCols())
|
||||
}
|
||||
|
||||
column := table.Column(0)
|
||||
validCount := 0
|
||||
for _, chunk := range column.Data().Chunks() {
|
||||
for i := 0; i < chunk.Len(); i++ {
|
||||
if chunk.IsValid(i) {
|
||||
validCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ret := make([]byte, validCount*dim)
|
||||
validData := make([]bool, r.numRows)
|
||||
offset := 0
|
||||
dataIdx := 0
|
||||
for _, chunk := range column.Data().Chunks() {
|
||||
binaryArray, ok := chunk.(*array.Binary)
|
||||
if !ok {
|
||||
return nil, -1, nil, 0, fmt.Errorf("expected Binary array for nullable vector, got %T", chunk)
|
||||
}
|
||||
for i := 0; i < binaryArray.Len(); i++ {
|
||||
if binaryArray.IsValid(i) {
|
||||
validData[offset+i] = true
|
||||
bytes := binaryArray.Value(i)
|
||||
copy(ret[dataIdx*dim:(dataIdx+1)*dim], bytes)
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
offset += binaryArray.Len()
|
||||
}
|
||||
|
||||
return ret, dim * 8, validData, int(r.numRows), nil
|
||||
}
|
||||
|
||||
col, err := r.reader.RowGroup(0).Column(0)
|
||||
if err != nil {
|
||||
return nil, -1, err
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
dim := col.Descriptor().TypeLength()
|
||||
|
||||
values := make([]parquet.FixedLenByteArray, r.numRows)
|
||||
valuesRead, err := ReadDataFromAllRowGroups[parquet.FixedLenByteArray, *file.FixedLenByteArrayColumnChunkReader](r.reader, values, 0, r.numRows)
|
||||
if err != nil {
|
||||
return nil, -1, err
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
|
||||
if valuesRead != r.numRows {
|
||||
return nil, -1, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead)
|
||||
return nil, -1, nil, 0, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead)
|
||||
}
|
||||
|
||||
ret := make([]byte, int64(dim)*r.numRows)
|
||||
for i := 0; i < int(r.numRows); i++ {
|
||||
copy(ret[i*dim:(i+1)*dim], values[i])
|
||||
}
|
||||
return ret, dim * 8, nil
|
||||
return ret, dim * 8, nil, int(r.numRows), nil
|
||||
}
|
||||
|
||||
// GetFloat16VectorFromPayload returns vector, dimension, error
|
||||
func (r *PayloadReader) GetFloat16VectorFromPayload() ([]byte, int, error) {
|
||||
// GetFloat16VectorFromPayload returns vector, dimension, validData, numRows, error
|
||||
func (r *PayloadReader) GetFloat16VectorFromPayload() ([]byte, int, []bool, int, error) {
|
||||
if r.colType != schemapb.DataType_Float16Vector {
|
||||
return nil, -1, fmt.Errorf("failed to get float vector from datatype %v", r.colType.String())
|
||||
return nil, -1, nil, 0, fmt.Errorf("failed to get float16 vector from datatype %v", r.colType.String())
|
||||
}
|
||||
if r.nullable {
|
||||
fileReader, err := pqarrow.NewFileReader(r.reader, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
|
||||
arrowSchema, err := fileReader.Schema()
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
|
||||
if arrowSchema.NumFields() != 1 {
|
||||
return nil, -1, nil, 0, fmt.Errorf("expected 1 field, got %d", arrowSchema.NumFields())
|
||||
}
|
||||
|
||||
field := arrowSchema.Field(0)
|
||||
var dim int
|
||||
|
||||
if field.Type.ID() == arrow.BINARY {
|
||||
if !field.HasMetadata() {
|
||||
return nil, -1, nil, 0, fmt.Errorf("nullable float16 vector field is missing metadata")
|
||||
}
|
||||
metadata := field.Metadata
|
||||
dimStr, ok := metadata.GetValue("dim")
|
||||
if !ok {
|
||||
return nil, -1, nil, 0, fmt.Errorf("nullable float16 vector metadata missing required 'dim' field")
|
||||
}
|
||||
var err error
|
||||
dim, err = strconv.Atoi(dimStr)
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, fmt.Errorf("invalid dim value in metadata: %v", err)
|
||||
}
|
||||
} else {
|
||||
col, err := r.reader.RowGroup(0).Column(0)
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
dim = col.Descriptor().TypeLength() / 2
|
||||
}
|
||||
|
||||
table, err := fileReader.ReadTable(context.Background())
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
defer table.Release()
|
||||
|
||||
if table.NumCols() != 1 {
|
||||
return nil, -1, nil, 0, fmt.Errorf("expected 1 column, got %d", table.NumCols())
|
||||
}
|
||||
|
||||
column := table.Column(0)
|
||||
validCount := 0
|
||||
for _, chunk := range column.Data().Chunks() {
|
||||
for i := 0; i < chunk.Len(); i++ {
|
||||
if chunk.IsValid(i) {
|
||||
validCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ret := make([]byte, validCount*dim*2)
|
||||
validData := make([]bool, r.numRows)
|
||||
offset := 0
|
||||
dataIdx := 0
|
||||
for _, chunk := range column.Data().Chunks() {
|
||||
binaryArray, ok := chunk.(*array.Binary)
|
||||
if !ok {
|
||||
return nil, -1, nil, 0, fmt.Errorf("expected Binary array for nullable vector, got %T", chunk)
|
||||
}
|
||||
for i := 0; i < binaryArray.Len(); i++ {
|
||||
if binaryArray.IsValid(i) {
|
||||
validData[offset+i] = true
|
||||
bytes := binaryArray.Value(i)
|
||||
copy(ret[dataIdx*dim*2:(dataIdx+1)*dim*2], bytes)
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
offset += binaryArray.Len()
|
||||
}
|
||||
|
||||
return ret, dim, validData, int(r.numRows), nil
|
||||
}
|
||||
|
||||
col, err := r.reader.RowGroup(0).Column(0)
|
||||
if err != nil {
|
||||
return nil, -1, err
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
|
||||
dim := col.Descriptor().TypeLength() / 2
|
||||
|
||||
values := make([]parquet.FixedLenByteArray, r.numRows)
|
||||
valuesRead, err := ReadDataFromAllRowGroups[parquet.FixedLenByteArray, *file.FixedLenByteArrayColumnChunkReader](r.reader, values, 0, r.numRows)
|
||||
if err != nil {
|
||||
return nil, -1, err
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
|
||||
if valuesRead != r.numRows {
|
||||
return nil, -1, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead)
|
||||
return nil, -1, nil, 0, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead)
|
||||
}
|
||||
|
||||
ret := make([]byte, int64(dim*2)*r.numRows)
|
||||
for i := 0; i < int(r.numRows); i++ {
|
||||
copy(ret[i*dim*2:(i+1)*dim*2], values[i])
|
||||
}
|
||||
return ret, dim, nil
|
||||
return ret, dim, nil, int(r.numRows), nil
|
||||
}
|
||||
|
||||
// GetBFloat16VectorFromPayload returns vector, dimension, error
|
||||
func (r *PayloadReader) GetBFloat16VectorFromPayload() ([]byte, int, error) {
|
||||
// GetBFloat16VectorFromPayload returns vector, dimension, validData, numRows, error
|
||||
func (r *PayloadReader) GetBFloat16VectorFromPayload() ([]byte, int, []bool, int, error) {
|
||||
if r.colType != schemapb.DataType_BFloat16Vector {
|
||||
return nil, -1, fmt.Errorf("failed to get float vector from datatype %v", r.colType.String())
|
||||
return nil, -1, nil, 0, fmt.Errorf("failed to get bfloat16 vector from datatype %v", r.colType.String())
|
||||
}
|
||||
if r.nullable {
|
||||
fileReader, err := pqarrow.NewFileReader(r.reader, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
|
||||
arrowSchema, err := fileReader.Schema()
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
|
||||
if arrowSchema.NumFields() != 1 {
|
||||
return nil, -1, nil, 0, fmt.Errorf("expected 1 field, got %d", arrowSchema.NumFields())
|
||||
}
|
||||
|
||||
field := arrowSchema.Field(0)
|
||||
var dim int
|
||||
|
||||
if field.Type.ID() == arrow.BINARY {
|
||||
if !field.HasMetadata() {
|
||||
return nil, -1, nil, 0, fmt.Errorf("nullable bfloat16 vector field is missing metadata")
|
||||
}
|
||||
metadata := field.Metadata
|
||||
dimStr, ok := metadata.GetValue("dim")
|
||||
if !ok {
|
||||
return nil, -1, nil, 0, fmt.Errorf("nullable bfloat16 vector metadata missing required 'dim' field")
|
||||
}
|
||||
var err error
|
||||
dim, err = strconv.Atoi(dimStr)
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, fmt.Errorf("invalid dim value in metadata: %v", err)
|
||||
}
|
||||
} else {
|
||||
col, err := r.reader.RowGroup(0).Column(0)
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
dim = col.Descriptor().TypeLength() / 2
|
||||
}
|
||||
|
||||
table, err := fileReader.ReadTable(context.Background())
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
defer table.Release()
|
||||
|
||||
if table.NumCols() != 1 {
|
||||
return nil, -1, nil, 0, fmt.Errorf("expected 1 column, got %d", table.NumCols())
|
||||
}
|
||||
|
||||
column := table.Column(0)
|
||||
validCount := 0
|
||||
for _, chunk := range column.Data().Chunks() {
|
||||
for i := 0; i < chunk.Len(); i++ {
|
||||
if chunk.IsValid(i) {
|
||||
validCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ret := make([]byte, validCount*dim*2)
|
||||
validData := make([]bool, r.numRows)
|
||||
offset := 0
|
||||
dataIdx := 0
|
||||
for _, chunk := range column.Data().Chunks() {
|
||||
binaryArray, ok := chunk.(*array.Binary)
|
||||
if !ok {
|
||||
return nil, -1, nil, 0, fmt.Errorf("expected Binary array for nullable vector, got %T", chunk)
|
||||
}
|
||||
for i := 0; i < binaryArray.Len(); i++ {
|
||||
if binaryArray.IsValid(i) {
|
||||
validData[offset+i] = true
|
||||
bytes := binaryArray.Value(i)
|
||||
copy(ret[dataIdx*dim*2:(dataIdx+1)*dim*2], bytes)
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
offset += binaryArray.Len()
|
||||
}
|
||||
|
||||
return ret, dim, validData, int(r.numRows), nil
|
||||
}
|
||||
|
||||
col, err := r.reader.RowGroup(0).Column(0)
|
||||
if err != nil {
|
||||
return nil, -1, err
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
|
||||
dim := col.Descriptor().TypeLength() / 2
|
||||
|
||||
values := make([]parquet.FixedLenByteArray, r.numRows)
|
||||
valuesRead, err := ReadDataFromAllRowGroups[parquet.FixedLenByteArray, *file.FixedLenByteArrayColumnChunkReader](r.reader, values, 0, r.numRows)
|
||||
if err != nil {
|
||||
return nil, -1, err
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
|
||||
if valuesRead != r.numRows {
|
||||
return nil, -1, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead)
|
||||
return nil, -1, nil, 0, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead)
|
||||
}
|
||||
|
||||
ret := make([]byte, int64(dim*2)*r.numRows)
|
||||
for i := 0; i < int(r.numRows); i++ {
|
||||
copy(ret[i*dim*2:(i+1)*dim*2], values[i])
|
||||
}
|
||||
return ret, dim, nil
|
||||
return ret, dim, nil, int(r.numRows), nil
|
||||
}
|
||||
|
||||
// GetFloatVectorFromPayload returns vector, dimension, error
|
||||
func (r *PayloadReader) GetFloatVectorFromPayload() ([]float32, int, error) {
|
||||
// GetFloatVectorFromPayload returns vector, dimension, validData, numRows, error
|
||||
func (r *PayloadReader) GetFloatVectorFromPayload() ([]float32, int, []bool, int, error) {
|
||||
if r.colType != schemapb.DataType_FloatVector {
|
||||
return nil, -1, fmt.Errorf("failed to get float vector from datatype %v", r.colType.String())
|
||||
return nil, -1, nil, 0, fmt.Errorf("failed to get float vector from datatype %v", r.colType.String())
|
||||
}
|
||||
if r.nullable {
|
||||
fileReader, err := pqarrow.NewFileReader(r.reader, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
|
||||
arrowSchema, err := fileReader.Schema()
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
|
||||
if arrowSchema.NumFields() != 1 {
|
||||
return nil, -1, nil, 0, fmt.Errorf("expected 1 field, got %d", arrowSchema.NumFields())
|
||||
}
|
||||
|
||||
field := arrowSchema.Field(0)
|
||||
var dim int
|
||||
|
||||
if field.Type.ID() == arrow.BINARY {
|
||||
if !field.HasMetadata() {
|
||||
return nil, -1, nil, 0, fmt.Errorf("nullable float vector field is missing metadata")
|
||||
}
|
||||
metadata := field.Metadata
|
||||
dimStr, ok := metadata.GetValue("dim")
|
||||
if !ok {
|
||||
return nil, -1, nil, 0, fmt.Errorf("nullable float vector metadata missing required 'dim' field")
|
||||
}
|
||||
var err error
|
||||
dim, err = strconv.Atoi(dimStr)
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, fmt.Errorf("invalid dim value in metadata: %v", err)
|
||||
}
|
||||
} else {
|
||||
col, err := r.reader.RowGroup(0).Column(0)
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
dim = col.Descriptor().TypeLength() / 4
|
||||
}
|
||||
|
||||
table, err := fileReader.ReadTable(context.Background())
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
defer table.Release()
|
||||
|
||||
if table.NumCols() != 1 {
|
||||
return nil, -1, nil, 0, fmt.Errorf("expected 1 column, got %d", table.NumCols())
|
||||
}
|
||||
|
||||
column := table.Column(0)
|
||||
validCount := 0
|
||||
for _, chunk := range column.Data().Chunks() {
|
||||
for i := 0; i < chunk.Len(); i++ {
|
||||
if chunk.IsValid(i) {
|
||||
validCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ret := make([]float32, validCount*dim)
|
||||
validData := make([]bool, r.numRows)
|
||||
offset := 0
|
||||
dataIdx := 0
|
||||
for _, chunk := range column.Data().Chunks() {
|
||||
binaryArray, ok := chunk.(*array.Binary)
|
||||
if !ok {
|
||||
return nil, -1, nil, 0, fmt.Errorf("expected Binary array for nullable vector, got %T", chunk)
|
||||
}
|
||||
for i := 0; i < binaryArray.Len(); i++ {
|
||||
if binaryArray.IsValid(i) {
|
||||
validData[offset+i] = true
|
||||
bytes := binaryArray.Value(i)
|
||||
copy(arrow.Float32Traits.CastToBytes(ret[dataIdx*dim:(dataIdx+1)*dim]), bytes)
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
offset += binaryArray.Len()
|
||||
}
|
||||
|
||||
return ret, dim, validData, int(r.numRows), nil
|
||||
}
|
||||
|
||||
col, err := r.reader.RowGroup(0).Column(0)
|
||||
if err != nil {
|
||||
return nil, -1, err
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
|
||||
dim := col.Descriptor().TypeLength() / 4
|
||||
@ -778,38 +1116,89 @@ func (r *PayloadReader) GetFloatVectorFromPayload() ([]float32, int, error) {
|
||||
values := make([]parquet.FixedLenByteArray, r.numRows)
|
||||
valuesRead, err := ReadDataFromAllRowGroups[parquet.FixedLenByteArray, *file.FixedLenByteArrayColumnChunkReader](r.reader, values, 0, r.numRows)
|
||||
if err != nil {
|
||||
return nil, -1, err
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
|
||||
if valuesRead != r.numRows {
|
||||
return nil, -1, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead)
|
||||
return nil, -1, nil, 0, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead)
|
||||
}
|
||||
|
||||
ret := make([]float32, int64(dim)*r.numRows)
|
||||
for i := 0; i < int(r.numRows); i++ {
|
||||
copy(arrow.Float32Traits.CastToBytes(ret[i*dim:(i+1)*dim]), values[i])
|
||||
}
|
||||
return ret, dim, nil
|
||||
return ret, dim, nil, int(r.numRows), nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetSparseFloatVectorFromPayload() (*SparseFloatVectorFieldData, int, error) {
|
||||
// GetSparseFloatVectorFromPayload returns fieldData, dimension, validData, error
|
||||
func (r *PayloadReader) GetSparseFloatVectorFromPayload() (*SparseFloatVectorFieldData, int, []bool, error) {
|
||||
if !typeutil.IsSparseFloatVectorType(r.colType) {
|
||||
return nil, -1, fmt.Errorf("failed to get sparse float vector from datatype %v", r.colType.String())
|
||||
return nil, -1, nil, fmt.Errorf("failed to get sparse float vector from datatype %v", r.colType.String())
|
||||
}
|
||||
|
||||
if r.nullable {
|
||||
fieldData := &SparseFloatVectorFieldData{}
|
||||
validData := make([]bool, r.numRows)
|
||||
|
||||
fileReader, err := pqarrow.NewFileReader(r.reader, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
|
||||
if err != nil {
|
||||
return nil, -1, nil, err
|
||||
}
|
||||
|
||||
table, err := fileReader.ReadTable(context.Background())
|
||||
if err != nil {
|
||||
return nil, -1, nil, err
|
||||
}
|
||||
defer table.Release()
|
||||
|
||||
if table.NumCols() != 1 {
|
||||
return nil, -1, nil, fmt.Errorf("expected 1 column, got %d", table.NumCols())
|
||||
}
|
||||
|
||||
column := table.Column(0)
|
||||
offset := 0
|
||||
for _, chunk := range column.Data().Chunks() {
|
||||
binaryArray, ok := chunk.(*array.Binary)
|
||||
if !ok {
|
||||
return nil, -1, nil, fmt.Errorf("expected Binary array, got %T", chunk)
|
||||
}
|
||||
|
||||
for i := 0; i < binaryArray.Len(); i++ {
|
||||
validData[offset+i] = binaryArray.IsValid(i)
|
||||
if validData[offset+i] {
|
||||
value := binaryArray.Value(i)
|
||||
if len(value)%8 != 0 {
|
||||
return nil, -1, nil, errors.New("invalid bytesData length")
|
||||
}
|
||||
fieldData.Contents = append(fieldData.Contents, value)
|
||||
rowDim := typeutil.SparseFloatRowDim(value)
|
||||
if rowDim > fieldData.Dim {
|
||||
fieldData.Dim = rowDim
|
||||
}
|
||||
} else {
|
||||
fieldData.Contents = append(fieldData.Contents, nil)
|
||||
}
|
||||
}
|
||||
offset += binaryArray.Len()
|
||||
}
|
||||
|
||||
return fieldData, int(fieldData.Dim), validData, nil
|
||||
}
|
||||
|
||||
values := make([]parquet.ByteArray, r.numRows)
|
||||
valuesRead, err := ReadDataFromAllRowGroups[parquet.ByteArray, *file.ByteArrayColumnChunkReader](r.reader, values, 0, r.numRows)
|
||||
if err != nil {
|
||||
return nil, -1, err
|
||||
return nil, -1, nil, err
|
||||
}
|
||||
if valuesRead != r.numRows {
|
||||
return nil, -1, fmt.Errorf("expect %d binary, but got = %d", r.numRows, valuesRead)
|
||||
return nil, -1, nil, fmt.Errorf("expect %d binary, but got = %d", r.numRows, valuesRead)
|
||||
}
|
||||
|
||||
fieldData := &SparseFloatVectorFieldData{}
|
||||
|
||||
for _, value := range values {
|
||||
if len(value)%8 != 0 {
|
||||
return nil, -1, errors.New("invalid bytesData length")
|
||||
return nil, -1, nil, errors.New("invalid bytesData length")
|
||||
}
|
||||
|
||||
fieldData.Contents = append(fieldData.Contents, value)
|
||||
@ -819,17 +1208,101 @@ func (r *PayloadReader) GetSparseFloatVectorFromPayload() (*SparseFloatVectorFie
|
||||
}
|
||||
}
|
||||
|
||||
return fieldData, int(fieldData.Dim), nil
|
||||
return fieldData, int(fieldData.Dim), nil, nil
|
||||
}
|
||||
|
||||
// GetInt8VectorFromPayload returns vector, dimension, error
|
||||
func (r *PayloadReader) GetInt8VectorFromPayload() ([]int8, int, error) {
|
||||
// GetInt8VectorFromPayload returns vector, dimension, validData, numRows, error
|
||||
func (r *PayloadReader) GetInt8VectorFromPayload() ([]int8, int, []bool, int, error) {
|
||||
if r.colType != schemapb.DataType_Int8Vector {
|
||||
return nil, -1, fmt.Errorf("failed to get int8 vector from datatype %v", r.colType.String())
|
||||
return nil, -1, nil, 0, fmt.Errorf("failed to get int8 vector from datatype %v", r.colType.String())
|
||||
}
|
||||
if r.nullable {
|
||||
fileReader, err := pqarrow.NewFileReader(r.reader, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
|
||||
arrowSchema, err := fileReader.Schema()
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
|
||||
if arrowSchema.NumFields() != 1 {
|
||||
return nil, -1, nil, 0, fmt.Errorf("expected 1 field, got %d", arrowSchema.NumFields())
|
||||
}
|
||||
|
||||
field := arrowSchema.Field(0)
|
||||
var dim int
|
||||
|
||||
if field.Type.ID() == arrow.BINARY {
|
||||
if !field.HasMetadata() {
|
||||
return nil, -1, nil, 0, fmt.Errorf("nullable int8 vector field is missing metadata")
|
||||
}
|
||||
metadata := field.Metadata
|
||||
dimStr, ok := metadata.GetValue("dim")
|
||||
if !ok {
|
||||
return nil, -1, nil, 0, fmt.Errorf("nullable int8 vector metadata missing required 'dim' field")
|
||||
}
|
||||
var err error
|
||||
dim, err = strconv.Atoi(dimStr)
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, fmt.Errorf("invalid dim value in metadata: %v", err)
|
||||
}
|
||||
} else {
|
||||
col, err := r.reader.RowGroup(0).Column(0)
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
dim = col.Descriptor().TypeLength()
|
||||
}
|
||||
|
||||
table, err := fileReader.ReadTable(context.Background())
|
||||
if err != nil {
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
defer table.Release()
|
||||
|
||||
if table.NumCols() != 1 {
|
||||
return nil, -1, nil, 0, fmt.Errorf("expected 1 column, got %d", table.NumCols())
|
||||
}
|
||||
|
||||
column := table.Column(0)
|
||||
validCount := 0
|
||||
for _, chunk := range column.Data().Chunks() {
|
||||
for i := 0; i < chunk.Len(); i++ {
|
||||
if chunk.IsValid(i) {
|
||||
validCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ret := make([]int8, validCount*dim)
|
||||
validData := make([]bool, r.numRows)
|
||||
offset := 0
|
||||
dataIdx := 0
|
||||
for _, chunk := range column.Data().Chunks() {
|
||||
binaryArray, ok := chunk.(*array.Binary)
|
||||
if !ok {
|
||||
return nil, -1, nil, 0, fmt.Errorf("expected Binary array for nullable vector, got %T", chunk)
|
||||
}
|
||||
for i := 0; i < binaryArray.Len(); i++ {
|
||||
if binaryArray.IsValid(i) {
|
||||
validData[offset+i] = true
|
||||
bytes := binaryArray.Value(i)
|
||||
int8Vals := arrow.Int8Traits.CastFromBytes(bytes)
|
||||
copy(ret[dataIdx*dim:(dataIdx+1)*dim], int8Vals)
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
offset += binaryArray.Len()
|
||||
}
|
||||
|
||||
return ret, dim, validData, int(r.numRows), nil
|
||||
}
|
||||
|
||||
col, err := r.reader.RowGroup(0).Column(0)
|
||||
if err != nil {
|
||||
return nil, -1, err
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
|
||||
dim := col.Descriptor().TypeLength()
|
||||
@ -837,11 +1310,11 @@ func (r *PayloadReader) GetInt8VectorFromPayload() ([]int8, int, error) {
|
||||
values := make([]parquet.FixedLenByteArray, r.numRows)
|
||||
valuesRead, err := ReadDataFromAllRowGroups[parquet.FixedLenByteArray, *file.FixedLenByteArrayColumnChunkReader](r.reader, values, 0, r.numRows)
|
||||
if err != nil {
|
||||
return nil, -1, err
|
||||
return nil, -1, nil, 0, err
|
||||
}
|
||||
|
||||
if valuesRead != r.numRows {
|
||||
return nil, -1, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead)
|
||||
return nil, -1, nil, 0, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead)
|
||||
}
|
||||
|
||||
ret := make([]int8, int64(dim)*r.numRows)
|
||||
@ -849,7 +1322,7 @@ func (r *PayloadReader) GetInt8VectorFromPayload() ([]int8, int, error) {
|
||||
int8Vals := arrow.Int8Traits.CastFromBytes(values[i])
|
||||
copy(ret[i*dim:(i+1)*dim], int8Vals)
|
||||
}
|
||||
return ret, dim, nil
|
||||
return ret, dim, nil, int(r.numRows), nil
|
||||
}
|
||||
|
||||
func (r *PayloadReader) GetPayloadLengthFromReader() (int, error) {
|
||||
|
||||
@ -484,7 +484,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
|
||||
in2[i] = 1
|
||||
}
|
||||
|
||||
err = w.AddBinaryVectorToPayload(in, 8)
|
||||
err = w.AddBinaryVectorToPayload(in, 8, nil)
|
||||
assert.NoError(t, err)
|
||||
err = w.AddDataToPayloadForUT(in2, nil)
|
||||
assert.NoError(t, err)
|
||||
@ -505,7 +505,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, length, 24)
|
||||
|
||||
binVecs, dim, err := r.GetBinaryVectorFromPayload()
|
||||
binVecs, dim, _, _, err := r.GetBinaryVectorFromPayload()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 8, dim)
|
||||
assert.Equal(t, 24, len(binVecs))
|
||||
@ -524,7 +524,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddFloatVectorToPayload([]float32{1.0, 2.0}, 1)
|
||||
err = w.AddFloatVectorToPayload([]float32{1.0, 2.0}, 1, nil)
|
||||
assert.NoError(t, err)
|
||||
err = w.AddDataToPayloadForUT([]float32{3.0, 4.0}, nil)
|
||||
assert.NoError(t, err)
|
||||
@ -545,7 +545,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, length, 4)
|
||||
|
||||
floatVecs, dim, err := r.GetFloatVectorFromPayload()
|
||||
floatVecs, dim, _, _, err := r.GetFloatVectorFromPayload()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, dim)
|
||||
assert.Equal(t, 4, len(floatVecs))
|
||||
@ -566,7 +566,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddFloat16VectorToPayload([]byte{1, 2}, 1)
|
||||
err = w.AddFloat16VectorToPayload([]byte{1, 2}, 1, nil)
|
||||
assert.NoError(t, err)
|
||||
err = w.AddDataToPayloadForUT([]byte{3, 4}, nil)
|
||||
assert.NoError(t, err)
|
||||
@ -587,7 +587,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, length, 2)
|
||||
|
||||
float16Vecs, dim, err := r.GetFloat16VectorFromPayload()
|
||||
float16Vecs, dim, _, _, err := r.GetFloat16VectorFromPayload()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, dim)
|
||||
assert.Equal(t, 4, len(float16Vecs))
|
||||
@ -608,7 +608,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddBFloat16VectorToPayload([]byte{1, 2}, 1)
|
||||
err = w.AddBFloat16VectorToPayload([]byte{1, 2}, 1, nil)
|
||||
assert.NoError(t, err)
|
||||
err = w.AddDataToPayloadForUT([]byte{3, 4}, nil)
|
||||
assert.NoError(t, err)
|
||||
@ -629,7 +629,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, length, 2)
|
||||
|
||||
bfloat16Vecs, dim, err := r.GetBFloat16VectorFromPayload()
|
||||
bfloat16Vecs, dim, _, _, err := r.GetBFloat16VectorFromPayload()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, dim)
|
||||
assert.Equal(t, 4, len(bfloat16Vecs))
|
||||
@ -689,7 +689,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, length, 6)
|
||||
|
||||
floatVecs, dim, err := r.GetSparseFloatVectorFromPayload()
|
||||
floatVecs, dim, _, err := r.GetSparseFloatVectorFromPayload()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 600, dim)
|
||||
assert.Equal(t, 6, len(floatVecs.Contents))
|
||||
@ -743,7 +743,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, length, 3)
|
||||
|
||||
floatVecs, dim, err := r.GetSparseFloatVectorFromPayload()
|
||||
floatVecs, dim, _, err := r.GetSparseFloatVectorFromPayload()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, actualDim, dim)
|
||||
assert.Equal(t, 3, len(floatVecs.Contents))
|
||||
@ -951,16 +951,16 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = w.AddBinaryVectorToPayload([]byte{}, 8)
|
||||
err = w.AddBinaryVectorToPayload([]byte{}, 8, nil)
|
||||
assert.Error(t, err)
|
||||
err = w.AddBinaryVectorToPayload([]byte{1}, 0)
|
||||
err = w.AddBinaryVectorToPayload([]byte{1}, 0, nil)
|
||||
assert.Error(t, err)
|
||||
|
||||
err = w.AddBinaryVectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8)
|
||||
err = w.AddBinaryVectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8, nil)
|
||||
assert.Error(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Error(t, err)
|
||||
err = w.AddBinaryVectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8)
|
||||
err = w.AddBinaryVectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8, nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
t.Run("TestAddFloatVectorAfterFinish", func(t *testing.T) {
|
||||
@ -972,16 +972,16 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = w.AddFloatVectorToPayload([]float32{}, 8)
|
||||
err = w.AddFloatVectorToPayload([]float32{}, 8, nil)
|
||||
assert.Error(t, err)
|
||||
err = w.AddFloatVectorToPayload([]float32{1.0}, 0)
|
||||
err = w.AddFloatVectorToPayload([]float32{1.0}, 0, nil)
|
||||
assert.Error(t, err)
|
||||
|
||||
err = w.AddFloatVectorToPayload([]float32{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, 8)
|
||||
err = w.AddFloatVectorToPayload([]float32{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, 8, nil)
|
||||
assert.Error(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Error(t, err)
|
||||
err = w.AddFloatVectorToPayload([]float32{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, 8)
|
||||
err = w.AddFloatVectorToPayload([]float32{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, 8, nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
t.Run("TestAddFloat16VectorAfterFinish", func(t *testing.T) {
|
||||
@ -990,22 +990,22 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
|
||||
require.NotNil(t, w)
|
||||
defer w.Close()
|
||||
|
||||
err = w.AddFloat16VectorToPayload([]byte{}, 8)
|
||||
err = w.AddFloat16VectorToPayload([]byte{}, 8, nil)
|
||||
assert.Error(t, err)
|
||||
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = w.AddFloat16VectorToPayload([]byte{}, 8)
|
||||
err = w.AddFloat16VectorToPayload([]byte{}, 8, nil)
|
||||
assert.Error(t, err)
|
||||
err = w.AddFloat16VectorToPayload([]byte{1}, 0)
|
||||
err = w.AddFloat16VectorToPayload([]byte{1}, 0, nil)
|
||||
assert.Error(t, err)
|
||||
|
||||
err = w.AddFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8)
|
||||
err = w.AddFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8, nil)
|
||||
assert.Error(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Error(t, err)
|
||||
err = w.AddFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8)
|
||||
err = w.AddFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8, nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
t.Run("TestAddBFloat16VectorAfterFinish", func(t *testing.T) {
|
||||
@ -1014,22 +1014,22 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
|
||||
require.NotNil(t, w)
|
||||
defer w.Close()
|
||||
|
||||
err = w.AddBFloat16VectorToPayload([]byte{}, 8)
|
||||
err = w.AddBFloat16VectorToPayload([]byte{}, 8, nil)
|
||||
assert.Error(t, err)
|
||||
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = w.AddBFloat16VectorToPayload([]byte{}, 8)
|
||||
err = w.AddBFloat16VectorToPayload([]byte{}, 8, nil)
|
||||
assert.Error(t, err)
|
||||
err = w.AddBFloat16VectorToPayload([]byte{1}, 0)
|
||||
err = w.AddBFloat16VectorToPayload([]byte{1}, 0, nil)
|
||||
assert.Error(t, err)
|
||||
|
||||
err = w.AddBFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8)
|
||||
err = w.AddBFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8, nil)
|
||||
assert.Error(t, err)
|
||||
err = w.FinishPayloadWriter()
|
||||
assert.Error(t, err)
|
||||
err = w.AddBFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8)
|
||||
err = w.AddBFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8, nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
t.Run("TestAddSparseFloatVectorAfterFinish", func(t *testing.T) {
|
||||
@ -1481,11 +1481,11 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
|
||||
r, err := NewPayloadReader(schemapb.DataType_BinaryVector, buffer, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, _, err = r.GetBinaryVectorFromPayload()
|
||||
_, _, _, _, err = r.GetBinaryVectorFromPayload()
|
||||
assert.Error(t, err)
|
||||
|
||||
r.colType = 999
|
||||
_, _, err = r.GetBinaryVectorFromPayload()
|
||||
_, _, _, _, err = r.GetBinaryVectorFromPayload()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
t.Run("TestGetBinaryVectorError2", func(t *testing.T) {
|
||||
@ -1493,7 +1493,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddBinaryVectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8)
|
||||
err = w.AddBinaryVectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = w.FinishPayloadWriter()
|
||||
@ -1506,7 +1506,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
r.numRows = 99
|
||||
_, _, err = r.GetBinaryVectorFromPayload()
|
||||
_, _, _, _, err = r.GetBinaryVectorFromPayload()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
t.Run("TestGetFloatVectorError", func(t *testing.T) {
|
||||
@ -1526,11 +1526,11 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
|
||||
r, err := NewPayloadReader(schemapb.DataType_FloatVector, buffer, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, _, err = r.GetFloatVectorFromPayload()
|
||||
_, _, _, _, err = r.GetFloatVectorFromPayload()
|
||||
assert.Error(t, err)
|
||||
|
||||
r.colType = 999
|
||||
_, _, err = r.GetFloatVectorFromPayload()
|
||||
_, _, _, _, err = r.GetFloatVectorFromPayload()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
t.Run("TestGetFloatVectorError2", func(t *testing.T) {
|
||||
@ -1538,7 +1538,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddFloatVectorToPayload([]float32{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, 8)
|
||||
err = w.AddFloatVectorToPayload([]float32{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, 8, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = w.FinishPayloadWriter()
|
||||
@ -1551,7 +1551,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
r.numRows = 99
|
||||
_, _, err = r.GetFloatVectorFromPayload()
|
||||
_, _, _, _, err = r.GetFloatVectorFromPayload()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
@ -1599,7 +1599,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_FloatVector)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = w.AddFloatVectorToPayload(vec, 128)
|
||||
err = w.AddFloatVectorToPayload(vec, 128, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = w.FinishPayloadWriter()
|
||||
@ -2234,19 +2234,548 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) {
|
||||
w.ReleasePayloadWriter()
|
||||
})
|
||||
|
||||
t.Run("TestBinaryVector", func(t *testing.T) {
|
||||
_, err := NewPayloadWriter(schemapb.DataType_BinaryVector, WithNullable(true), WithDim(8))
|
||||
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
||||
t.Run("TestFloatVector", func(t *testing.T) {
|
||||
dim := 128
|
||||
numRows := 100
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
validDataSetup func([]bool) int
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "half null",
|
||||
validDataSetup: func(validData []bool) int {
|
||||
validCount := 0
|
||||
for i := 0; i < numRows; i++ {
|
||||
if i%2 == 0 {
|
||||
validData[i] = true
|
||||
validCount++
|
||||
}
|
||||
}
|
||||
return validCount
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all valid",
|
||||
validDataSetup: func(validData []bool) int {
|
||||
for i := 0; i < numRows; i++ {
|
||||
validData[i] = true
|
||||
}
|
||||
return numRows
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all null",
|
||||
validDataSetup: func(validData []bool) int {
|
||||
return 0
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_FloatVector, WithDim(dim), WithNullable(true))
|
||||
require.NoError(t, err)
|
||||
|
||||
validData := make([]bool, numRows)
|
||||
validCount := tc.validDataSetup(validData)
|
||||
|
||||
data := make([]float32, validCount*dim)
|
||||
dataIdx := 0
|
||||
for i := 0; i < numRows; i++ {
|
||||
if validData[i] {
|
||||
for j := 0; j < dim; j++ {
|
||||
data[dataIdx*dim+j] = float32(i*100 + j)
|
||||
}
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
|
||||
err = w.AddFloatVectorToPayload(data, dim, validData)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = w.FinishPayloadWriter()
|
||||
require.NoError(t, err)
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_FloatVector, buffer, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
readData, readDim, readValid, readNumRows, err := r.GetFloatVectorFromPayload()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, dim, readDim)
|
||||
require.Equal(t, numRows, readNumRows)
|
||||
require.Equal(t, numRows, len(readValid))
|
||||
|
||||
dataIdx = 0
|
||||
for i := 0; i < numRows; i++ {
|
||||
require.Equal(t, validData[i], readValid[i])
|
||||
if validData[i] {
|
||||
pos := dataIdx
|
||||
for j := 0; j < dim; j++ {
|
||||
require.Equal(t, data[dataIdx*dim+j], readData[pos*dim+j])
|
||||
}
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TestFloatVector", func(t *testing.T) {
|
||||
_, err := NewPayloadWriter(schemapb.DataType_FloatVector, WithNullable(true), WithDim(1))
|
||||
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
||||
t.Run("TestBinaryVector", func(t *testing.T) {
|
||||
dim := 128
|
||||
numRows := 100
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
validDataSetup func([]bool) int
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "partial null",
|
||||
validDataSetup: func(validData []bool) int {
|
||||
validCount := 0
|
||||
for i := 0; i < numRows; i++ {
|
||||
if i%3 == 0 {
|
||||
validData[i] = true
|
||||
validCount++
|
||||
}
|
||||
}
|
||||
return validCount
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all valid",
|
||||
validDataSetup: func(validData []bool) int {
|
||||
for i := 0; i < numRows; i++ {
|
||||
validData[i] = true
|
||||
}
|
||||
return numRows
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all null",
|
||||
validDataSetup: func(validData []bool) int {
|
||||
return 0
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_BinaryVector, WithDim(dim), WithNullable(true))
|
||||
require.NoError(t, err)
|
||||
|
||||
validData := make([]bool, numRows)
|
||||
validCount := tc.validDataSetup(validData)
|
||||
|
||||
data := make([]byte, validCount*dim/8)
|
||||
dataIdx := 0
|
||||
for i := 0; i < numRows; i++ {
|
||||
if validData[i] {
|
||||
for j := 0; j < dim/8; j++ {
|
||||
data[dataIdx*dim/8+j] = byte(i + j)
|
||||
}
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
|
||||
err = w.AddBinaryVectorToPayload(data, dim, validData)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = w.FinishPayloadWriter()
|
||||
require.NoError(t, err)
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_BinaryVector, buffer, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
readData, readDim, readValid, readNumRows, err := r.GetBinaryVectorFromPayload()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, dim, readDim)
|
||||
require.Equal(t, numRows, readNumRows)
|
||||
require.Equal(t, numRows, len(readValid))
|
||||
|
||||
dataIdx = 0
|
||||
for i := 0; i < numRows; i++ {
|
||||
require.Equal(t, validData[i], readValid[i])
|
||||
if validData[i] {
|
||||
pos := dataIdx
|
||||
for j := 0; j < dim/8; j++ {
|
||||
require.Equal(t, data[dataIdx*dim/8+j], readData[pos*dim/8+j])
|
||||
}
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TestFloat16Vector", func(t *testing.T) {
|
||||
_, err := NewPayloadWriter(schemapb.DataType_Float16Vector, WithNullable(true), WithDim(1))
|
||||
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
||||
dim := 128
|
||||
numRows := 100
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
validDataSetup func([]bool) int
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "partial null",
|
||||
validDataSetup: func(validData []bool) int {
|
||||
validCount := 0
|
||||
for i := 0; i < numRows; i++ {
|
||||
if i%2 == 1 {
|
||||
validData[i] = true
|
||||
validCount++
|
||||
}
|
||||
}
|
||||
return validCount
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all valid",
|
||||
validDataSetup: func(validData []bool) int {
|
||||
for i := 0; i < numRows; i++ {
|
||||
validData[i] = true
|
||||
}
|
||||
return numRows
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all null",
|
||||
validDataSetup: func(validData []bool) int {
|
||||
return 0
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_Float16Vector, WithDim(dim), WithNullable(true))
|
||||
require.NoError(t, err)
|
||||
|
||||
validData := make([]bool, numRows)
|
||||
validCount := tc.validDataSetup(validData)
|
||||
|
||||
data := make([]byte, validCount*dim*2)
|
||||
dataIdx := 0
|
||||
for i := 0; i < numRows; i++ {
|
||||
if validData[i] {
|
||||
for j := 0; j < dim*2; j++ {
|
||||
data[dataIdx*dim*2+j] = byte((i*10 + j) % 256)
|
||||
}
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
|
||||
err = w.AddFloat16VectorToPayload(data, dim, validData)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = w.FinishPayloadWriter()
|
||||
require.NoError(t, err)
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_Float16Vector, buffer, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
readData, readDim, readValid, readNumRows, err := r.GetFloat16VectorFromPayload()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, dim, readDim)
|
||||
require.Equal(t, numRows, readNumRows)
|
||||
require.Equal(t, numRows, len(readValid))
|
||||
|
||||
dataIdx = 0
|
||||
for i := 0; i < numRows; i++ {
|
||||
require.Equal(t, validData[i], readValid[i])
|
||||
if validData[i] {
|
||||
pos := dataIdx
|
||||
for j := 0; j < dim*2; j++ {
|
||||
require.Equal(t, data[dataIdx*dim*2+j], readData[pos*dim*2+j])
|
||||
}
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TestBFloat16Vector", func(t *testing.T) {
|
||||
dim := 128
|
||||
numRows := 100
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
validDataSetup func([]bool) int
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "partial null",
|
||||
validDataSetup: func(validData []bool) int {
|
||||
validCount := 0
|
||||
for i := 0; i < numRows; i++ {
|
||||
if (i+1)%3 != 0 {
|
||||
validData[i] = true
|
||||
validCount++
|
||||
}
|
||||
}
|
||||
return validCount
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all valid",
|
||||
validDataSetup: func(validData []bool) int {
|
||||
for i := 0; i < numRows; i++ {
|
||||
validData[i] = true
|
||||
}
|
||||
return numRows
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all null",
|
||||
validDataSetup: func(validData []bool) int {
|
||||
return 0
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_BFloat16Vector, WithDim(dim), WithNullable(true))
|
||||
require.NoError(t, err)
|
||||
|
||||
validData := make([]bool, numRows)
|
||||
validCount := tc.validDataSetup(validData)
|
||||
|
||||
data := make([]byte, validCount*dim*2)
|
||||
dataIdx := 0
|
||||
for i := 0; i < numRows; i++ {
|
||||
if validData[i] {
|
||||
for j := 0; j < dim*2; j++ {
|
||||
data[dataIdx*dim*2+j] = byte((i*20 + j) % 256)
|
||||
}
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
|
||||
err = w.AddBFloat16VectorToPayload(data, dim, validData)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = w.FinishPayloadWriter()
|
||||
require.NoError(t, err)
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_BFloat16Vector, buffer, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
readData, readDim, readValid, readNumRows, err := r.GetBFloat16VectorFromPayload()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, dim, readDim)
|
||||
require.Equal(t, numRows, readNumRows)
|
||||
require.Equal(t, numRows, len(readValid))
|
||||
|
||||
dataIdx = 0
|
||||
for i := 0; i < numRows; i++ {
|
||||
require.Equal(t, validData[i], readValid[i])
|
||||
if validData[i] {
|
||||
pos := dataIdx
|
||||
for j := 0; j < dim*2; j++ {
|
||||
require.Equal(t, data[dataIdx*dim*2+j], readData[pos*dim*2+j])
|
||||
}
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TestInt8Vector", func(t *testing.T) {
|
||||
dim := 128
|
||||
numRows := 100
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
validDataSetup func([]bool) int
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "partial null",
|
||||
validDataSetup: func(validData []bool) int {
|
||||
validCount := 0
|
||||
for i := 0; i < numRows; i++ {
|
||||
if i < numRows/2 {
|
||||
validData[i] = true
|
||||
validCount++
|
||||
}
|
||||
}
|
||||
return validCount
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all valid",
|
||||
validDataSetup: func(validData []bool) int {
|
||||
for i := 0; i < numRows; i++ {
|
||||
validData[i] = true
|
||||
}
|
||||
return numRows
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all null",
|
||||
validDataSetup: func(validData []bool) int {
|
||||
return 0
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_Int8Vector, WithDim(dim), WithNullable(true))
|
||||
require.NoError(t, err)
|
||||
|
||||
validData := make([]bool, numRows)
|
||||
validCount := tc.validDataSetup(validData)
|
||||
|
||||
data := make([]int8, validCount*dim)
|
||||
dataIdx := 0
|
||||
for i := 0; i < numRows; i++ {
|
||||
if validData[i] {
|
||||
for j := 0; j < dim; j++ {
|
||||
data[dataIdx*dim+j] = int8((i*10 + j) % 128)
|
||||
}
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
|
||||
err = w.AddInt8VectorToPayload(data, dim, validData)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = w.FinishPayloadWriter()
|
||||
require.NoError(t, err)
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_Int8Vector, buffer, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
readData, readDim, readValid, readNumRows, err := r.GetInt8VectorFromPayload()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, dim, readDim)
|
||||
require.Equal(t, numRows, readNumRows)
|
||||
require.Equal(t, numRows, len(readValid))
|
||||
|
||||
dataIdx = 0
|
||||
for i := 0; i < numRows; i++ {
|
||||
require.Equal(t, validData[i], readValid[i])
|
||||
if validData[i] {
|
||||
pos := dataIdx
|
||||
for j := 0; j < dim; j++ {
|
||||
require.Equal(t, data[dataIdx*dim+j], readData[pos*dim+j])
|
||||
}
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TestSparseFloatVector", func(t *testing.T) {
|
||||
numRows := 100
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
validDataSetup func([]bool) int
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "half null",
|
||||
validDataSetup: func(validData []bool) int {
|
||||
validCount := 0
|
||||
for i := 0; i < numRows; i++ {
|
||||
if i%2 == 0 {
|
||||
validData[i] = true
|
||||
validCount++
|
||||
}
|
||||
}
|
||||
return validCount
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all valid",
|
||||
validDataSetup: func(validData []bool) int {
|
||||
for i := 0; i < numRows; i++ {
|
||||
validData[i] = true
|
||||
}
|
||||
return numRows
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all null",
|
||||
validDataSetup: func(validData []bool) int {
|
||||
return 0
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
w, err := NewPayloadWriter(schemapb.DataType_SparseFloatVector, WithNullable(true))
|
||||
require.NoError(t, err)
|
||||
|
||||
validData := make([]bool, numRows)
|
||||
tc.validDataSetup(validData)
|
||||
|
||||
data := &SparseFloatVectorFieldData{
|
||||
SparseFloatArray: schemapb.SparseFloatArray{
|
||||
Dim: 100,
|
||||
},
|
||||
ValidData: validData,
|
||||
}
|
||||
for i := 0; i < numRows; i++ {
|
||||
if validData[i] {
|
||||
sparseVec := make([]byte, 16)
|
||||
for j := 0; j < 16; j++ {
|
||||
sparseVec[j] = byte((i*10 + j) % 256)
|
||||
}
|
||||
data.SparseFloatArray.Contents = append(data.SparseFloatArray.Contents, sparseVec)
|
||||
}
|
||||
}
|
||||
|
||||
err = w.AddSparseFloatVectorToPayload(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = w.FinishPayloadWriter()
|
||||
require.NoError(t, err)
|
||||
|
||||
buffer, err := w.GetPayloadBufferFromWriter()
|
||||
require.NoError(t, err)
|
||||
|
||||
r, err := NewPayloadReader(schemapb.DataType_SparseFloatVector, buffer, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
readData, _, readValid, err := r.GetSparseFloatVectorFromPayload()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, numRows, len(readValid))
|
||||
require.Equal(t, numRows, len(readData.Contents))
|
||||
|
||||
for i := 0; i < numRows; i++ {
|
||||
require.Equal(t, validData[i], readValid[i])
|
||||
if validData[i] {
|
||||
require.NotNil(t, readData.Contents[i])
|
||||
require.Equal(t, 16, len(readData.Contents[i]))
|
||||
for j := 0; j < 16; j++ {
|
||||
require.Equal(t, byte((i*10+j)%256), readData.Contents[i][j])
|
||||
}
|
||||
} else {
|
||||
require.Nil(t, readData.Contents[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TestAddBool with wrong valids", func(t *testing.T) {
|
||||
|
||||
@ -103,9 +103,6 @@ func NewPayloadWriter(colType schemapb.DataType, options ...PayloadWriterOptions
|
||||
if w.dim.IsNull() {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("incorrect input numbers")
|
||||
}
|
||||
if w.nullable {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("vector type does not support nullable")
|
||||
}
|
||||
} else {
|
||||
w.dim = NewNullableInt(1)
|
||||
}
|
||||
@ -125,8 +122,13 @@ func NewPayloadWriter(colType schemapb.DataType, options ...PayloadWriterOptions
|
||||
w.arrowType = arrow.ListOf(elemType)
|
||||
w.builder = array.NewListBuilder(memory.DefaultAllocator, elemType)
|
||||
} else {
|
||||
w.arrowType = MilvusDataTypeToArrowType(colType, *w.dim.Value)
|
||||
w.builder = array.NewBuilder(memory.DefaultAllocator, w.arrowType)
|
||||
if w.nullable && typeutil.IsVectorType(colType) && !typeutil.IsSparseFloatVectorType(colType) {
|
||||
w.arrowType = &arrow.BinaryType{}
|
||||
w.builder = array.NewBinaryBuilder(memory.DefaultAllocator, arrow.BinaryTypes.Binary)
|
||||
} else {
|
||||
w.arrowType = MilvusDataTypeToArrowType(colType, *w.dim.Value)
|
||||
w.builder = array.NewBuilder(memory.DefaultAllocator, w.arrowType)
|
||||
}
|
||||
}
|
||||
return w, nil
|
||||
}
|
||||
@ -262,25 +264,25 @@ func (w *NativePayloadWriter) AddDataToPayloadForUT(data interface{}, validData
|
||||
if !ok {
|
||||
return merr.WrapErrParameterInvalidMsg("incorrect data type")
|
||||
}
|
||||
return w.AddBinaryVectorToPayload(val, w.dim.GetValue())
|
||||
return w.AddBinaryVectorToPayload(val, w.dim.GetValue(), validData)
|
||||
case schemapb.DataType_FloatVector:
|
||||
val, ok := data.([]float32)
|
||||
if !ok {
|
||||
return merr.WrapErrParameterInvalidMsg("incorrect data type")
|
||||
}
|
||||
return w.AddFloatVectorToPayload(val, w.dim.GetValue())
|
||||
return w.AddFloatVectorToPayload(val, w.dim.GetValue(), validData)
|
||||
case schemapb.DataType_Float16Vector:
|
||||
val, ok := data.([]byte)
|
||||
if !ok {
|
||||
return merr.WrapErrParameterInvalidMsg("incorrect data type")
|
||||
}
|
||||
return w.AddFloat16VectorToPayload(val, w.dim.GetValue())
|
||||
return w.AddFloat16VectorToPayload(val, w.dim.GetValue(), validData)
|
||||
case schemapb.DataType_BFloat16Vector:
|
||||
val, ok := data.([]byte)
|
||||
if !ok {
|
||||
return merr.WrapErrParameterInvalidMsg("incorrect data type")
|
||||
}
|
||||
return w.AddBFloat16VectorToPayload(val, w.dim.GetValue())
|
||||
return w.AddBFloat16VectorToPayload(val, w.dim.GetValue(), validData)
|
||||
case schemapb.DataType_SparseFloatVector:
|
||||
val, ok := data.(*SparseFloatVectorFieldData)
|
||||
if !ok {
|
||||
@ -292,7 +294,7 @@ func (w *NativePayloadWriter) AddDataToPayloadForUT(data interface{}, validData
|
||||
if !ok {
|
||||
return merr.WrapErrParameterInvalidMsg("incorrect data type")
|
||||
}
|
||||
return w.AddInt8VectorToPayload(val, w.dim.GetValue())
|
||||
return w.AddInt8VectorToPayload(val, w.dim.GetValue(), validData)
|
||||
case schemapb.DataType_ArrayOfVector:
|
||||
val, ok := data.(*VectorArrayFieldData)
|
||||
if !ok {
|
||||
@ -660,106 +662,262 @@ func (w *NativePayloadWriter) AddOneGeometryToPayload(data []byte, isValid bool)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *NativePayloadWriter) AddBinaryVectorToPayload(data []byte, dim int) error {
|
||||
func (w *NativePayloadWriter) AddBinaryVectorToPayload(data []byte, dim int, validData []bool) error {
|
||||
if w.finished {
|
||||
return errors.New("can't append data to finished binary vector payload")
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
return errors.New("can't add empty msgs into binary vector payload")
|
||||
}
|
||||
|
||||
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
|
||||
if !ok {
|
||||
return errors.New("failed to cast BinaryVectorBuilder")
|
||||
}
|
||||
|
||||
byteLength := dim / 8
|
||||
length := len(data) / byteLength
|
||||
builder.Reserve(length)
|
||||
for i := 0; i < length; i++ {
|
||||
builder.Append(data[i*byteLength : (i+1)*byteLength])
|
||||
var numRows int
|
||||
if w.nullable && len(validData) > 0 {
|
||||
numRows = len(validData)
|
||||
validCount := 0
|
||||
for _, valid := range validData {
|
||||
if valid {
|
||||
validCount++
|
||||
}
|
||||
}
|
||||
expectedDataLen := validCount * byteLength
|
||||
if len(data) != expectedDataLen {
|
||||
msg := fmt.Sprintf("when nullable, data length(%d) must equal to valid count(%d) * byteLength(%d) = %d", len(data), validCount, byteLength, expectedDataLen)
|
||||
return merr.WrapErrParameterInvalidMsg(msg)
|
||||
}
|
||||
} else {
|
||||
if len(data) == 0 {
|
||||
return errors.New("can't add empty msgs into binary vector payload")
|
||||
}
|
||||
numRows = len(data) / byteLength
|
||||
if !w.nullable && len(validData) != 0 {
|
||||
msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
|
||||
return merr.WrapErrParameterInvalidMsg(msg)
|
||||
}
|
||||
}
|
||||
|
||||
if w.nullable {
|
||||
builder, ok := w.builder.(*array.BinaryBuilder)
|
||||
if !ok {
|
||||
return errors.New("failed to cast to BinaryBuilder for nullable BinaryVector")
|
||||
}
|
||||
|
||||
builder.Reserve(numRows)
|
||||
dataIdx := 0
|
||||
for i := 0; i < numRows; i++ {
|
||||
if len(validData) > 0 && !validData[i] {
|
||||
builder.AppendNull()
|
||||
} else {
|
||||
builder.Append(data[dataIdx*byteLength : (dataIdx+1)*byteLength])
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
} else {
|
||||
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
|
||||
if !ok {
|
||||
return errors.New("failed to cast to FixedSizeBinaryBuilder for non-nullable BinaryVector")
|
||||
}
|
||||
|
||||
builder.Reserve(numRows)
|
||||
for i := 0; i < numRows; i++ {
|
||||
builder.Append(data[i*byteLength : (i+1)*byteLength])
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *NativePayloadWriter) AddFloatVectorToPayload(data []float32, dim int) error {
|
||||
func (w *NativePayloadWriter) AddFloatVectorToPayload(data []float32, dim int, validData []bool) error {
|
||||
if w.finished {
|
||||
return errors.New("can't append data to finished float vector payload")
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
return errors.New("can't add empty msgs into float vector payload")
|
||||
}
|
||||
|
||||
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
|
||||
if !ok {
|
||||
return errors.New("failed to cast FloatVectorBuilder")
|
||||
var numRows int
|
||||
if w.nullable && len(validData) > 0 {
|
||||
numRows = len(validData)
|
||||
validCount := 0
|
||||
for _, valid := range validData {
|
||||
if valid {
|
||||
validCount++
|
||||
}
|
||||
}
|
||||
expectedDataLen := validCount * dim
|
||||
if len(data) != expectedDataLen {
|
||||
msg := fmt.Sprintf("when nullable, data length(%d) must equal to valid count(%d) * dim(%d) = %d", len(data), validCount, dim, expectedDataLen)
|
||||
return merr.WrapErrParameterInvalidMsg(msg)
|
||||
}
|
||||
} else {
|
||||
if len(data) == 0 {
|
||||
return errors.New("can't add empty msgs into float vector payload")
|
||||
}
|
||||
numRows = len(data) / dim
|
||||
if !w.nullable && len(validData) != 0 {
|
||||
msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
|
||||
return merr.WrapErrParameterInvalidMsg(msg)
|
||||
}
|
||||
}
|
||||
|
||||
byteLength := dim * 4
|
||||
length := len(data) / dim
|
||||
|
||||
builder.Reserve(length)
|
||||
bytesData := make([]byte, byteLength)
|
||||
for i := 0; i < length; i++ {
|
||||
vec := data[i*dim : (i+1)*dim]
|
||||
for j := range vec {
|
||||
bytes := math.Float32bits(vec[j])
|
||||
common.Endian.PutUint32(bytesData[j*4:], bytes)
|
||||
if w.nullable {
|
||||
builder, ok := w.builder.(*array.BinaryBuilder)
|
||||
if !ok {
|
||||
return errors.New("failed to cast to BinaryBuilder for nullable FloatVector")
|
||||
}
|
||||
|
||||
builder.Reserve(numRows)
|
||||
bytesData := make([]byte, byteLength)
|
||||
dataIdx := 0
|
||||
for i := 0; i < numRows; i++ {
|
||||
if len(validData) > 0 && !validData[i] {
|
||||
builder.AppendNull()
|
||||
} else {
|
||||
vec := data[dataIdx*dim : (dataIdx+1)*dim]
|
||||
for j := range vec {
|
||||
bytes := math.Float32bits(vec[j])
|
||||
common.Endian.PutUint32(bytesData[j*4:], bytes)
|
||||
}
|
||||
builder.Append(bytesData)
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
} else {
|
||||
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
|
||||
if !ok {
|
||||
return errors.New("failed to cast to FixedSizeBinaryBuilder for non-nullable FloatVector")
|
||||
}
|
||||
|
||||
builder.Reserve(numRows)
|
||||
bytesData := make([]byte, byteLength)
|
||||
for i := 0; i < numRows; i++ {
|
||||
vec := data[i*dim : (i+1)*dim]
|
||||
for j := range vec {
|
||||
bytes := math.Float32bits(vec[j])
|
||||
common.Endian.PutUint32(bytesData[j*4:], bytes)
|
||||
}
|
||||
builder.Append(bytesData)
|
||||
}
|
||||
builder.Append(bytesData)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *NativePayloadWriter) AddFloat16VectorToPayload(data []byte, dim int) error {
|
||||
func (w *NativePayloadWriter) AddFloat16VectorToPayload(data []byte, dim int, validData []bool) error {
|
||||
if w.finished {
|
||||
return errors.New("can't append data to finished float16 payload")
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
return errors.New("can't add empty msgs into float16 payload")
|
||||
}
|
||||
|
||||
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
|
||||
if !ok {
|
||||
return errors.New("failed to cast Float16Builder")
|
||||
}
|
||||
|
||||
byteLength := dim * 2
|
||||
length := len(data) / byteLength
|
||||
var numRows int
|
||||
if w.nullable && len(validData) > 0 {
|
||||
numRows = len(validData)
|
||||
validCount := 0
|
||||
for _, valid := range validData {
|
||||
if valid {
|
||||
validCount++
|
||||
}
|
||||
}
|
||||
expectedDataLen := validCount * byteLength
|
||||
if len(data) != expectedDataLen {
|
||||
msg := fmt.Sprintf("when nullable, data length(%d) must equal to valid count(%d) * byteLength(%d) = %d", len(data), validCount, byteLength, expectedDataLen)
|
||||
return merr.WrapErrParameterInvalidMsg(msg)
|
||||
}
|
||||
} else {
|
||||
if len(data) == 0 {
|
||||
return errors.New("can't add empty msgs into float16 payload")
|
||||
}
|
||||
numRows = len(data) / byteLength
|
||||
if !w.nullable && len(validData) != 0 {
|
||||
msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
|
||||
return merr.WrapErrParameterInvalidMsg(msg)
|
||||
}
|
||||
}
|
||||
|
||||
builder.Reserve(length)
|
||||
for i := 0; i < length; i++ {
|
||||
builder.Append(data[i*byteLength : (i+1)*byteLength])
|
||||
if w.nullable {
|
||||
builder, ok := w.builder.(*array.BinaryBuilder)
|
||||
if !ok {
|
||||
return errors.New("failed to cast to BinaryBuilder for nullable Float16Vector")
|
||||
}
|
||||
|
||||
builder.Reserve(numRows)
|
||||
dataIdx := 0
|
||||
for i := 0; i < numRows; i++ {
|
||||
if len(validData) > 0 && !validData[i] {
|
||||
builder.AppendNull()
|
||||
} else {
|
||||
builder.Append(data[dataIdx*byteLength : (dataIdx+1)*byteLength])
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
} else {
|
||||
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
|
||||
if !ok {
|
||||
return errors.New("failed to cast to FixedSizeBinaryBuilder for non-nullable Float16Vector")
|
||||
}
|
||||
|
||||
builder.Reserve(numRows)
|
||||
for i := 0; i < numRows; i++ {
|
||||
builder.Append(data[i*byteLength : (i+1)*byteLength])
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *NativePayloadWriter) AddBFloat16VectorToPayload(data []byte, dim int) error {
|
||||
func (w *NativePayloadWriter) AddBFloat16VectorToPayload(data []byte, dim int, validData []bool) error {
|
||||
if w.finished {
|
||||
return errors.New("can't append data to finished BFloat16 payload")
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
return errors.New("can't add empty msgs into BFloat16 payload")
|
||||
}
|
||||
|
||||
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
|
||||
if !ok {
|
||||
return errors.New("failed to cast BFloat16Builder")
|
||||
}
|
||||
|
||||
byteLength := dim * 2
|
||||
length := len(data) / byteLength
|
||||
var numRows int
|
||||
if w.nullable && len(validData) > 0 {
|
||||
numRows = len(validData)
|
||||
validCount := 0
|
||||
for _, valid := range validData {
|
||||
if valid {
|
||||
validCount++
|
||||
}
|
||||
}
|
||||
expectedDataLen := validCount * byteLength
|
||||
if len(data) != expectedDataLen {
|
||||
msg := fmt.Sprintf("when nullable, data length(%d) must equal to valid count(%d) * byteLength(%d) = %d", len(data), validCount, byteLength, expectedDataLen)
|
||||
return merr.WrapErrParameterInvalidMsg(msg)
|
||||
}
|
||||
} else {
|
||||
if len(data) == 0 {
|
||||
return errors.New("can't add empty msgs into BFloat16 payload")
|
||||
}
|
||||
numRows = len(data) / byteLength
|
||||
if !w.nullable && len(validData) != 0 {
|
||||
msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
|
||||
return merr.WrapErrParameterInvalidMsg(msg)
|
||||
}
|
||||
}
|
||||
|
||||
builder.Reserve(length)
|
||||
for i := 0; i < length; i++ {
|
||||
builder.Append(data[i*byteLength : (i+1)*byteLength])
|
||||
if w.nullable {
|
||||
builder, ok := w.builder.(*array.BinaryBuilder)
|
||||
if !ok {
|
||||
return errors.New("failed to cast to BinaryBuilder for nullable BFloat16Vector")
|
||||
}
|
||||
|
||||
builder.Reserve(numRows)
|
||||
dataIdx := 0
|
||||
for i := 0; i < numRows; i++ {
|
||||
if len(validData) > 0 && !validData[i] {
|
||||
builder.AppendNull()
|
||||
} else {
|
||||
builder.Append(data[dataIdx*byteLength : (dataIdx+1)*byteLength])
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
} else {
|
||||
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
|
||||
if !ok {
|
||||
return errors.New("failed to cast to FixedSizeBinaryBuilder for non-nullable BFloat16Vector")
|
||||
}
|
||||
|
||||
builder.Reserve(numRows)
|
||||
for i := 0; i < numRows; i++ {
|
||||
builder.Append(data[i*byteLength : (i+1)*byteLength])
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -769,41 +927,107 @@ func (w *NativePayloadWriter) AddSparseFloatVectorToPayload(data *SparseFloatVec
|
||||
if w.finished {
|
||||
return errors.New("can't append data to finished sparse float vector payload")
|
||||
}
|
||||
|
||||
var numRows int
|
||||
if w.nullable && len(data.ValidData) > 0 {
|
||||
numRows = len(data.ValidData)
|
||||
validCount := 0
|
||||
for _, valid := range data.ValidData {
|
||||
if valid {
|
||||
validCount++
|
||||
}
|
||||
}
|
||||
if len(data.SparseFloatArray.Contents) != validCount {
|
||||
msg := fmt.Sprintf("when nullable, Contents length(%d) must equal to valid count(%d)", len(data.SparseFloatArray.Contents), validCount)
|
||||
return merr.WrapErrParameterInvalidMsg(msg)
|
||||
}
|
||||
} else {
|
||||
numRows = len(data.SparseFloatArray.Contents)
|
||||
if !w.nullable && len(data.ValidData) != 0 {
|
||||
msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(data.ValidData))
|
||||
return merr.WrapErrParameterInvalidMsg(msg)
|
||||
}
|
||||
}
|
||||
|
||||
builder, ok := w.builder.(*array.BinaryBuilder)
|
||||
if !ok {
|
||||
return errors.New("failed to cast SparseFloatVectorBuilder")
|
||||
}
|
||||
length := len(data.SparseFloatArray.Contents)
|
||||
builder.Reserve(length)
|
||||
for i := 0; i < length; i++ {
|
||||
builder.Append(data.SparseFloatArray.Contents[i])
|
||||
|
||||
builder.Reserve(numRows)
|
||||
dataIdx := 0
|
||||
for i := 0; i < numRows; i++ {
|
||||
if w.nullable && len(data.ValidData) > 0 && !data.ValidData[i] {
|
||||
builder.AppendNull()
|
||||
} else {
|
||||
builder.Append(data.SparseFloatArray.Contents[dataIdx])
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *NativePayloadWriter) AddInt8VectorToPayload(data []int8, dim int) error {
|
||||
func (w *NativePayloadWriter) AddInt8VectorToPayload(data []int8, dim int, validData []bool) error {
|
||||
if w.finished {
|
||||
return errors.New("can't append data to finished int8 vector payload")
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
return errors.New("can't add empty msgs into int8 vector payload")
|
||||
var numRows int
|
||||
if w.nullable && len(validData) > 0 {
|
||||
numRows = len(validData)
|
||||
validCount := 0
|
||||
for _, valid := range validData {
|
||||
if valid {
|
||||
validCount++
|
||||
}
|
||||
}
|
||||
expectedDataLen := validCount * dim
|
||||
if len(data) != expectedDataLen {
|
||||
msg := fmt.Sprintf("when nullable, data length(%d) must equal to valid count(%d) * dim(%d) = %d", len(data), validCount, dim, expectedDataLen)
|
||||
return merr.WrapErrParameterInvalidMsg(msg)
|
||||
}
|
||||
} else {
|
||||
if len(data) == 0 {
|
||||
return errors.New("can't add empty msgs into int8 vector payload")
|
||||
}
|
||||
numRows = len(data) / dim
|
||||
if !w.nullable && len(validData) != 0 {
|
||||
msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
|
||||
return merr.WrapErrParameterInvalidMsg(msg)
|
||||
}
|
||||
}
|
||||
|
||||
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
|
||||
if !ok {
|
||||
return errors.New("failed to cast Int8VectorBuilder")
|
||||
}
|
||||
if w.nullable {
|
||||
builder, ok := w.builder.(*array.BinaryBuilder)
|
||||
if !ok {
|
||||
return errors.New("failed to cast to BinaryBuilder for nullable Int8Vector")
|
||||
}
|
||||
|
||||
byteLength := dim
|
||||
length := len(data) / byteLength
|
||||
builder.Reserve(numRows)
|
||||
dataIdx := 0
|
||||
for i := 0; i < numRows; i++ {
|
||||
if len(validData) > 0 && !validData[i] {
|
||||
builder.AppendNull()
|
||||
} else {
|
||||
vec := data[dataIdx*dim : (dataIdx+1)*dim]
|
||||
vecBytes := arrow.Int8Traits.CastToBytes(vec)
|
||||
builder.Append(vecBytes)
|
||||
dataIdx++
|
||||
}
|
||||
}
|
||||
} else {
|
||||
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
|
||||
if !ok {
|
||||
return errors.New("failed to cast to FixedSizeBinaryBuilder for non-nullable Int8Vector")
|
||||
}
|
||||
|
||||
builder.Reserve(length)
|
||||
for i := 0; i < length; i++ {
|
||||
vec := data[i*dim : (i+1)*dim]
|
||||
vecBytes := arrow.Int8Traits.CastToBytes(vec)
|
||||
builder.Append(vecBytes)
|
||||
builder.Reserve(numRows)
|
||||
for i := 0; i < numRows; i++ {
|
||||
vec := data[i*dim : (i+1)*dim]
|
||||
vecBytes := arrow.Int8Traits.CastToBytes(vec)
|
||||
builder.Append(vecBytes)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -827,6 +1051,11 @@ func (w *NativePayloadWriter) FinishPayloadWriter() error {
|
||||
[]string{"elementType", "dim"},
|
||||
[]string{fmt.Sprintf("%d", int32(*w.elementType)), fmt.Sprintf("%d", w.dim.GetValue())},
|
||||
)
|
||||
} else if w.nullable && typeutil.IsVectorType(w.dataType) && !typeutil.IsSparseFloatVectorType(w.dataType) {
|
||||
metadata = arrow.NewMetadata(
|
||||
[]string{"dim"},
|
||||
[]string{fmt.Sprintf("%d", w.dim.GetValue())},
|
||||
)
|
||||
}
|
||||
|
||||
field := arrow.Field{
|
||||
@ -849,7 +1078,8 @@ func (w *NativePayloadWriter) FinishPayloadWriter() error {
|
||||
defer table.Release()
|
||||
|
||||
arrowWriterProps := pqarrow.DefaultWriterProps()
|
||||
if w.dataType == schemapb.DataType_ArrayOfVector {
|
||||
if w.dataType == schemapb.DataType_ArrayOfVector ||
|
||||
(w.nullable && typeutil.IsVectorType(w.dataType) && !typeutil.IsSparseFloatVectorType(w.dataType)) {
|
||||
// Store metadata in the Arrow writer properties
|
||||
arrowWriterProps = pqarrow.NewArrowWriterProperties(
|
||||
pqarrow.WithStoreSchema(),
|
||||
|
||||
@ -260,14 +260,14 @@ func TestPayloadWriter_Failed(t *testing.T) {
|
||||
err = w.FinishPayloadWriter()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = w.AddBinaryVectorToPayload(data, 8)
|
||||
err = w.AddBinaryVectorToPayload(data, 8, nil)
|
||||
require.Error(t, err)
|
||||
|
||||
w, err = NewPayloadWriter(schemapb.DataType_Int64)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, w)
|
||||
|
||||
err = w.AddBinaryVectorToPayload(data, 8)
|
||||
err = w.AddBinaryVectorToPayload(data, 8, nil)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
|
||||
@ -303,7 +303,7 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface
|
||||
fmt.Printf("\t\t%d : %s\n", i, val[i])
|
||||
}
|
||||
case schemapb.DataType_BinaryVector:
|
||||
val, dim, err := reader.GetBinaryVectorFromPayload()
|
||||
val, dim, _, _, err := reader.GetBinaryVectorFromPayload()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -318,7 +318,7 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface
|
||||
fmt.Println()
|
||||
}
|
||||
case schemapb.DataType_Float16Vector:
|
||||
val, dim, err := reader.GetFloat16VectorFromPayload()
|
||||
val, dim, _, _, err := reader.GetFloat16VectorFromPayload()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -333,7 +333,7 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface
|
||||
fmt.Println()
|
||||
}
|
||||
case schemapb.DataType_BFloat16Vector:
|
||||
val, dim, err := reader.GetBFloat16VectorFromPayload()
|
||||
val, dim, _, _, err := reader.GetBFloat16VectorFromPayload()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -349,7 +349,7 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface
|
||||
}
|
||||
|
||||
case schemapb.DataType_FloatVector:
|
||||
val, dim, err := reader.GetFloatVectorFromPayload()
|
||||
val, dim, _, _, err := reader.GetFloatVectorFromPayload()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -362,6 +362,20 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
case schemapb.DataType_Int8Vector:
|
||||
val, dim, _, _, err := reader.GetInt8VectorFromPayload()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
length := len(val) / dim
|
||||
for i := 0; i < length; i++ {
|
||||
fmt.Printf("\t\t%d :", i)
|
||||
for j := 0; j < dim; j++ {
|
||||
idx := i*dim + j
|
||||
fmt.Printf(" %d", val[idx])
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
case schemapb.DataType_JSON:
|
||||
|
||||
rows, err := reader.GetPayloadLengthFromReader()
|
||||
@ -397,7 +411,7 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface
|
||||
fmt.Printf("\t\t%d : %v\n", i, v)
|
||||
}
|
||||
case schemapb.DataType_SparseFloatVector:
|
||||
sparseData, _, err := reader.GetSparseFloatVectorFromPayload()
|
||||
sparseData, _, _, err := reader.GetSparseFloatVectorFromPayload()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -30,6 +30,7 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/etcdpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/uniquegenerator"
|
||||
)
|
||||
|
||||
@ -215,6 +216,23 @@ func TestPrintBinlogFiles(t *testing.T) {
|
||||
Description: "description_15",
|
||||
DataType: schemapb.DataType_Geometry,
|
||||
},
|
||||
{
|
||||
FieldID: 114,
|
||||
Name: "field_int8_vector",
|
||||
IsPrimaryKey: false,
|
||||
Description: "description_16",
|
||||
DataType: schemapb.DataType_Int8Vector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: common.DimKey, Value: "4"},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 115,
|
||||
Name: "field_sparse_float_vector",
|
||||
IsPrimaryKey: false,
|
||||
Description: "description_17",
|
||||
DataType: schemapb.DataType_SparseFloatVector,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -280,6 +298,19 @@ func TestPrintBinlogFiles(t *testing.T) {
|
||||
{0x01, 0x02, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0xD2, 0x4A, 0x4D, 0x6A, 0x8B, 0x3C, 0x5C, 0x0A, 0x0D, 0x1B, 0x4F, 0x4F, 0x9A, 0x3D, 0x40, 0x03, 0xA6, 0xB4, 0xA6, 0xA4, 0xD2, 0xC5, 0xC0, 0xD2, 0x4A, 0x4D, 0x6A, 0x8B, 0x3C, 0x5C, 0x0A},
|
||||
},
|
||||
},
|
||||
114: &Int8VectorFieldData{
|
||||
Data: []int8{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
Dim: 4,
|
||||
},
|
||||
115: &SparseFloatVectorFieldData{
|
||||
SparseFloatArray: schemapb.SparseFloatArray{
|
||||
Dim: 100,
|
||||
Contents: [][]byte{
|
||||
typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{1.1, 1.2, 1.3}),
|
||||
typeutil.CreateSparseFloatRow([]uint32{10, 20, 30}, []float32{2.1, 2.2, 2.3}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -344,6 +375,19 @@ func TestPrintBinlogFiles(t *testing.T) {
|
||||
{0x01, 0x02, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0xD2, 0x4A, 0x4D, 0x6A, 0x8B, 0x3C, 0x5C, 0x0A, 0x0D, 0x1B, 0x4F, 0x4F, 0x9A, 0x3D, 0x40, 0x03, 0xA6, 0xB4, 0xA6, 0xA4, 0xD2, 0xC5, 0xC0, 0xD2, 0x4A, 0x4D, 0x6A, 0x8B, 0x3C, 0x5C, 0x0A},
|
||||
},
|
||||
},
|
||||
114: &Int8VectorFieldData{
|
||||
Data: []int8{11, 12, 13, 14, 15, 16, 17, 18},
|
||||
Dim: 4,
|
||||
},
|
||||
115: &SparseFloatVectorFieldData{
|
||||
SparseFloatArray: schemapb.SparseFloatArray{
|
||||
Dim: 100,
|
||||
Contents: [][]byte{
|
||||
typeutil.CreateSparseFloatRow([]uint32{5, 6, 7}, []float32{3.1, 3.2, 3.3}),
|
||||
typeutil.CreateSparseFloatRow([]uint32{15, 25, 35}, []float32{4.1, 4.2, 4.3}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
firstBlobs, err := insertCodec.Serialize(1, 1, insertDataFirst)
|
||||
|
||||
@ -38,8 +38,28 @@ func ConvertToArrowSchema(schema *schemapb.CollectionSchema, useFieldID bool) (*
|
||||
}
|
||||
|
||||
arrowType := serdeMap[field.DataType].arrowType(dim, elementType)
|
||||
|
||||
if field.GetNullable() {
|
||||
switch field.DataType {
|
||||
case schemapb.DataType_BinaryVector, schemapb.DataType_FloatVector,
|
||||
schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector, schemapb.DataType_Int8Vector:
|
||||
arrowType = arrow.BinaryTypes.Binary
|
||||
}
|
||||
}
|
||||
|
||||
arrowField := ConvertToArrowField(field, arrowType, useFieldID)
|
||||
|
||||
if field.GetNullable() {
|
||||
switch field.DataType {
|
||||
case schemapb.DataType_BinaryVector, schemapb.DataType_FloatVector,
|
||||
schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector, schemapb.DataType_Int8Vector:
|
||||
arrowField.Metadata = arrow.NewMetadata(
|
||||
[]string{packed.ArrowFieldIdMetadataKey, "dim"},
|
||||
[]string{strconv.Itoa(int(field.GetFieldID())), strconv.Itoa(dim)},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Add extra metadata for ArrayOfVector
|
||||
if field.DataType == schemapb.DataType_ArrayOfVector {
|
||||
arrowField.Metadata = arrow.NewMetadata(
|
||||
|
||||
@ -43,12 +43,18 @@ func TestConvertArrowSchema(t *testing.T) {
|
||||
{FieldID: 14, Name: "field13", DataType: schemapb.DataType_Float16Vector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
|
||||
{FieldID: 15, Name: "field14", DataType: schemapb.DataType_BFloat16Vector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
|
||||
{FieldID: 16, Name: "field15", DataType: schemapb.DataType_Int8Vector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
|
||||
{FieldID: 17, Name: "field16", DataType: schemapb.DataType_BinaryVector, Nullable: true, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
|
||||
{FieldID: 18, Name: "field17", DataType: schemapb.DataType_FloatVector, Nullable: true, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
|
||||
{FieldID: 19, Name: "field18", DataType: schemapb.DataType_Float16Vector, Nullable: true, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
|
||||
{FieldID: 20, Name: "field19", DataType: schemapb.DataType_BFloat16Vector, Nullable: true, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
|
||||
{FieldID: 21, Name: "field20", DataType: schemapb.DataType_Int8Vector, Nullable: true, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
|
||||
{FieldID: 22, Name: "field21", DataType: schemapb.DataType_SparseFloatVector, Nullable: true},
|
||||
}
|
||||
|
||||
StructArrayFieldSchemas := []*schemapb.StructArrayFieldSchema{
|
||||
{FieldID: 17, Name: "struct_field0", Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 18, Name: "field16", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64},
|
||||
{FieldID: 19, Name: "field17", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float},
|
||||
{FieldID: 23, Name: "struct_field0", Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 24, Name: "field22", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64},
|
||||
{FieldID: 25, Name: "field23", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float},
|
||||
}},
|
||||
}
|
||||
|
||||
@ -59,6 +65,14 @@ func TestConvertArrowSchema(t *testing.T) {
|
||||
arrowSchema, err := ConvertToArrowSchema(schema, false)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(fieldSchemas)+len(StructArrayFieldSchemas[0].Fields), len(arrowSchema.Fields()))
|
||||
|
||||
for i, field := range arrowSchema.Fields() {
|
||||
if i >= 16 && i <= 20 {
|
||||
dimVal, ok := field.Metadata.GetValue("dim")
|
||||
assert.True(t, ok, "nullable vector field should have dim metadata")
|
||||
assert.Equal(t, "128", dimVal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertArrowSchemaWithoutDim(t *testing.T) {
|
||||
|
||||
@ -553,8 +553,12 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry {
|
||||
b.AppendNull()
|
||||
return true
|
||||
}
|
||||
if builder, ok := b.(*array.FixedSizeBinaryBuilder); ok {
|
||||
if v, ok := v.([]byte); ok {
|
||||
if v, ok := v.([]byte); ok {
|
||||
if builder, ok := b.(*array.FixedSizeBinaryBuilder); ok {
|
||||
builder.Append(v)
|
||||
return true
|
||||
}
|
||||
if builder, ok := b.(*array.BinaryBuilder); ok {
|
||||
builder.Append(v)
|
||||
return true
|
||||
}
|
||||
@ -607,14 +611,21 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry {
|
||||
b.AppendNull()
|
||||
return true
|
||||
}
|
||||
var bytesData []byte
|
||||
if vv, ok := v.([]byte); ok {
|
||||
bytesData = vv
|
||||
} else if vv, ok := v.([]int8); ok {
|
||||
bytesData = arrow.Int8Traits.CastToBytes(vv)
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
if builder, ok := b.(*array.FixedSizeBinaryBuilder); ok {
|
||||
if vv, ok := v.([]byte); ok {
|
||||
builder.Append(vv)
|
||||
return true
|
||||
} else if vv, ok := v.([]int8); ok {
|
||||
builder.Append(arrow.Int8Traits.CastToBytes(vv))
|
||||
return true
|
||||
}
|
||||
builder.Append(bytesData)
|
||||
return true
|
||||
}
|
||||
if builder, ok := b.(*array.BinaryBuilder); ok {
|
||||
builder.Append(bytesData)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
@ -643,15 +654,19 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry {
|
||||
b.AppendNull()
|
||||
return true
|
||||
}
|
||||
if builder, ok := b.(*array.FixedSizeBinaryBuilder); ok {
|
||||
if vv, ok := v.([]float32); ok {
|
||||
dim := len(vv)
|
||||
byteLength := dim * 4
|
||||
bytesData := make([]byte, byteLength)
|
||||
for i, vec := range vv {
|
||||
bytes := math.Float32bits(vec)
|
||||
common.Endian.PutUint32(bytesData[i*4:], bytes)
|
||||
}
|
||||
if vv, ok := v.([]float32); ok {
|
||||
dim := len(vv)
|
||||
byteLength := dim * 4
|
||||
bytesData := make([]byte, byteLength)
|
||||
for i, vec := range vv {
|
||||
bytes := math.Float32bits(vec)
|
||||
common.Endian.PutUint32(bytesData[i*4:], bytes)
|
||||
}
|
||||
if builder, ok := b.(*array.FixedSizeBinaryBuilder); ok {
|
||||
builder.Append(bytesData)
|
||||
return true
|
||||
}
|
||||
if builder, ok := b.(*array.BinaryBuilder); ok {
|
||||
builder.Append(bytesData)
|
||||
return true
|
||||
}
|
||||
@ -987,7 +1002,16 @@ func newSingleFieldRecordWriter(field *schemapb.FieldSchema, writer io.Writer, o
|
||||
[]string{fmt.Sprintf("%d", int32(elementType)), fmt.Sprintf("%d", dim)},
|
||||
)
|
||||
}
|
||||
arrowType = serdeMap[field.DataType].arrowType(int(dim), elementType)
|
||||
|
||||
if field.GetNullable() && typeutil.IsVectorType(field.DataType) && !typeutil.IsSparseFloatVectorType(field.DataType) {
|
||||
arrowType = arrow.BinaryTypes.Binary
|
||||
fieldMetadata = arrow.NewMetadata(
|
||||
[]string{"dim"},
|
||||
[]string{fmt.Sprintf("%d", dim)},
|
||||
)
|
||||
} else {
|
||||
arrowType = serdeMap[field.DataType].arrowType(int(dim), elementType)
|
||||
}
|
||||
|
||||
w := &singleFieldRecordWriter{
|
||||
fieldId: field.FieldID,
|
||||
@ -1199,10 +1223,40 @@ func BuildRecord(b *array.RecordBuilder, data *InsertData, schema *schemapb.Coll
|
||||
elementType = field.GetElementType()
|
||||
}
|
||||
|
||||
for j := 0; j < fieldData.RowNum(); j++ {
|
||||
ok = typeEntry.serialize(fBuilder, fieldData.GetRow(j), elementType)
|
||||
if !ok {
|
||||
return merr.WrapErrServiceInternal(fmt.Sprintf("serialize error on type %s", field.DataType.String()))
|
||||
if field.GetNullable() && typeutil.IsVectorType(field.DataType) {
|
||||
var validData []bool
|
||||
switch fd := fieldData.(type) {
|
||||
case *FloatVectorFieldData:
|
||||
validData = fd.ValidData
|
||||
case *BinaryVectorFieldData:
|
||||
validData = fd.ValidData
|
||||
case *Float16VectorFieldData:
|
||||
validData = fd.ValidData
|
||||
case *BFloat16VectorFieldData:
|
||||
validData = fd.ValidData
|
||||
case *SparseFloatVectorFieldData:
|
||||
validData = fd.ValidData
|
||||
case *Int8VectorFieldData:
|
||||
validData = fd.ValidData
|
||||
}
|
||||
// Use len(validData) as logical row count, GetRow takes logical index
|
||||
for j := 0; j < len(validData); j++ {
|
||||
if !validData[j] {
|
||||
fBuilder.(*array.BinaryBuilder).AppendNull()
|
||||
} else {
|
||||
rowData := fieldData.GetRow(j)
|
||||
ok = typeEntry.serialize(fBuilder, rowData, elementType)
|
||||
if !ok {
|
||||
return merr.WrapErrServiceInternal(fmt.Sprintf("serialize error on type %s", field.DataType.String()))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for j := 0; j < fieldData.RowNum(); j++ {
|
||||
ok = typeEntry.serialize(fBuilder, fieldData.GetRow(j), elementType)
|
||||
if !ok {
|
||||
return merr.WrapErrServiceInternal(fmt.Sprintf("serialize error on type %s", field.DataType.String()))
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user