milvus/tests/go_client/testcases/nullable_default_value_test.go
marcelo-cjl 3b599441fd
feat: Add nullable vector support for proxy and querynode (#46305)
related: #45993 

This commit extends nullable vector support to the proxy layer and the
querynode, and adds comprehensive validation, search-reduce, and field-data
handling for nullable vectors with sparse storage.

Proxy layer changes:
- Update validate_util.go checkAligned() with a getExpectedVectorRows()
  helper to validate nullable vector field alignment using the valid data
  count (a hedged sketch of this rule follows the list below)
- Update checkFloatVectorFieldData/checkSparseFloatVectorFieldData for
  nullable vector validation with proper row count expectations
- Add FieldDataIdxComputer in typeutil/schema.go for logical-to-physical
  index translation during search-reduce operations
- Update search_reduce_util.go reduceSearchResultData to use idxComputers
  for correct field data indexing with nullable vectors
- Update task.go, task_query.go, task_upsert.go for nullable vector handling
- Update msg_pack.go with nullable vector field data processing
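
The alignment rule above can be summarized with a short sketch. This is a
hedged illustration, not the actual validate_util.go code: the helper name
comes from the commit description, but the signature and types below are
assumptions.

```go
// Hypothetical sketch of the alignment rule for nullable vector fields:
// vectors are stored densely (valid rows only), so the vector payload must
// contain exactly as many rows as valid_data has true entries.
func getExpectedVectorRows(numRows int, nullable bool, validData []bool) int {
	if !nullable {
		return numRows
	}
	expected := 0
	for _, valid := range validData {
		if valid {
			expected++
		}
	}
	return expected
}
```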
    
QueryNode layer changes:
- Update segments/result.go for nullable vector result handling
- Update segments/search_reduce.go with nullable vector offset translation

Storage and index changes:
- Update data_codec.go and utils.go for nullable vector serialization
- Update indexcgowrapper/dataset.go and index.go for nullable vector indexing

Utility changes:
- Add FieldDataIdxComputer struct with a Compute() method for efficient
  logical-to-physical index mapping across multiple field data (see the
  sketch after this list)
- Update EstimateEntitySize() and AppendFieldData() with a fieldIdxs parameter
- Update funcutil.go with nullable vector support functions
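
A minimal sketch of the logical-to-physical mapping idea, assuming one
validity bitmap per nullable vector field; the type and method names are
illustrative and may differ from the real FieldDataIdxComputer in
typeutil/schema.go:

```go
// Sketch only: translate a logical row offset (counting null and valid rows)
// into the physical offset inside the densely stored valid vectors.
type idxComputer struct {
	valid  []bool
	prefix []int // prefix[i] = number of valid rows before logical row i
}

func newIdxComputer(valid []bool) *idxComputer {
	prefix := make([]int, len(valid)+1)
	for i, v := range valid {
		prefix[i+1] = prefix[i]
		if v {
			prefix[i+1]++
		}
	}
	return &idxComputer{valid: valid, prefix: prefix}
}

// compute returns (physicalIdx, true) for a valid row and (-1, false) for a null row.
func (c *idxComputer) compute(logical int) (int, bool) {
	if logical < 0 || logical >= len(c.valid) || !c.valid[logical] {
		return -1, false
	}
	return c.prefix[logical], true
}
```

For valid = [true, false, true], compute(2) returns (1, true) since one valid
row precedes logical row 2, and compute(1) returns (-1, false); search reduce
can then read vector data at the physical offset while keeping result rows in
logical order.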

## Summary by CodeRabbit

* **New Features**
  * Full support for nullable vector fields (float, binary, float16, bfloat16, int8, sparse) across ingest, storage, indexing, search, and retrieval; logical↔physical offset mapping preserves row semantics.
  * Client: compaction control and compaction-state APIs.

* **Bug Fixes**
  * Improved validation for adding vector fields (nullable + dimension checks) and corrected search/query behavior for nullable vectors.

* **Chores**
  * Persisted validity maps with indexes and on-disk formats.

* **Tests**
  * Extensive new and updated end-to-end nullable-vector tests.


---------

Signed-off-by: marcelo-cjl <marcelo.chen@zilliz.com>
2025-12-24 10:13:19 +08:00


package testcases

import (
	"fmt"
	"math"
	"strings"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	"go.uber.org/zap"

	"github.com/milvus-io/milvus/client/v2/column"
	"github.com/milvus-io/milvus/client/v2/entity"
	"github.com/milvus-io/milvus/client/v2/index"
	client "github.com/milvus-io/milvus/client/v2/milvusclient"
	"github.com/milvus-io/milvus/pkg/v2/log"
	"github.com/milvus-io/milvus/tests/go_client/common"
	hp "github.com/milvus-io/milvus/tests/go_client/testcases/helper"
)
func int64SliceToString(ids []int64) string {
strs := make([]string, len(ids))
for i, id := range ids {
strs[i] = fmt.Sprintf("%d", id)
}
return strings.Join(strs, ", ")
}
type NullableVectorType struct {
Name string
FieldType entity.FieldType
}
func GetVectorTypes() []NullableVectorType {
return []NullableVectorType{
{"FloatVector", entity.FieldTypeFloatVector},
{"BinaryVector", entity.FieldTypeBinaryVector},
{"Float16Vector", entity.FieldTypeFloat16Vector},
{"BFloat16Vector", entity.FieldTypeBFloat16Vector},
{"Int8Vector", entity.FieldTypeInt8Vector},
{"SparseVector", entity.FieldTypeSparseVector},
}
}
func GetNullPercents() []int {
return []int{0, 30}
}
type NullableVectorTestData struct {
ValidData []bool
ValidCount int
PkToVecIdx map[int64]int
OriginalVectors interface{}
VecColumn column.Column
SearchVec entity.Vector
}
func GenerateNullableVectorTestData(t *testing.T, vt NullableVectorType, nb int, nullPercent int, fieldName string) *NullableVectorTestData {
data := &NullableVectorTestData{
ValidData: make([]bool, nb),
PkToVecIdx: make(map[int64]int),
}
for i := range nb {
data.ValidData[i] = (i % 100) >= nullPercent
if data.ValidData[i] {
data.ValidCount++
}
}
vecIdx := 0
for i := range nb {
if data.ValidData[i] {
data.PkToVecIdx[int64(i)] = vecIdx
vecIdx++
}
}
var err error
switch vt.FieldType {
case entity.FieldTypeFloatVector:
vectors := make([][]float32, data.ValidCount)
for i := range data.ValidCount {
vec := make([]float32, common.DefaultDim)
for j := range common.DefaultDim {
vec[j] = float32(i*common.DefaultDim+j) / 10000.0
}
vectors[i] = vec
}
data.OriginalVectors = vectors
data.VecColumn, err = column.NewNullableColumnFloatVector(fieldName, common.DefaultDim, vectors, data.ValidData)
if data.ValidCount > 0 {
data.SearchVec = entity.FloatVector(vectors[0])
}
case entity.FieldTypeBinaryVector:
byteDim := common.DefaultDim / 8
vectors := make([][]byte, data.ValidCount)
for i := range data.ValidCount {
vec := make([]byte, byteDim)
for j := range byteDim {
vec[j] = byte((i + j) % 256)
}
vectors[i] = vec
}
data.OriginalVectors = vectors
data.VecColumn, err = column.NewNullableColumnBinaryVector(fieldName, common.DefaultDim, vectors, data.ValidData)
if data.ValidCount > 0 {
data.SearchVec = entity.BinaryVector(vectors[0])
}
case entity.FieldTypeFloat16Vector:
vectors := make([][]byte, data.ValidCount)
for i := range data.ValidCount {
vectors[i] = common.GenFloat16Vector(common.DefaultDim)
}
data.OriginalVectors = vectors
data.VecColumn, err = column.NewNullableColumnFloat16Vector(fieldName, common.DefaultDim, vectors, data.ValidData)
if data.ValidCount > 0 {
data.SearchVec = entity.Float16Vector(vectors[0])
}
case entity.FieldTypeBFloat16Vector:
vectors := make([][]byte, data.ValidCount)
for i := range data.ValidCount {
vectors[i] = common.GenBFloat16Vector(common.DefaultDim)
}
data.OriginalVectors = vectors
data.VecColumn, err = column.NewNullableColumnBFloat16Vector(fieldName, common.DefaultDim, vectors, data.ValidData)
if data.ValidCount > 0 {
data.SearchVec = entity.BFloat16Vector(vectors[0])
}
case entity.FieldTypeInt8Vector:
vectors := make([][]int8, data.ValidCount)
for i := range data.ValidCount {
vec := make([]int8, common.DefaultDim)
for j := range common.DefaultDim {
vec[j] = int8((i + j) % 127)
}
vectors[i] = vec
}
data.OriginalVectors = vectors
data.VecColumn, err = column.NewNullableColumnInt8Vector(fieldName, common.DefaultDim, vectors, data.ValidData)
if data.ValidCount > 0 {
data.SearchVec = entity.Int8Vector(vectors[0])
}
case entity.FieldTypeSparseVector:
vectors := make([]entity.SparseEmbedding, data.ValidCount)
for i := range data.ValidCount {
positions := []uint32{0, uint32(i + 1), uint32(i + 1000)}
values := []float32{1.0, float32(i+1) / 1000.0, 0.1}
vectors[i], err = entity.NewSliceSparseEmbedding(positions, values)
common.CheckErr(t, err, true)
}
data.OriginalVectors = vectors
data.VecColumn, err = column.NewNullableColumnSparseFloatVector(fieldName, vectors, data.ValidData)
if data.ValidCount > 0 {
data.SearchVec = vectors[0]
}
}
common.CheckErr(t, err, true)
return data
}
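// sketchNullableTestData is an illustrative sketch only (not part of the
// upstream suite): it documents the generator's contract. With nullPercent=30,
// rows where i%100 < 30 are null, so nb=200 yields 140 valid vectors stored
// densely, and PkToVecIdx maps each valid pk to its dense index.
func sketchNullableTestData(t *testing.T) {
	data := GenerateNullableVectorTestData(t, NullableVectorType{Name: "FloatVector", FieldType: entity.FieldTypeFloatVector}, 200, 30, "vector")
	require.EqualValues(t, 140, data.ValidCount)
	// pk 0 falls in the null range (0%100 < 30), so it has no dense index
	_, ok := data.PkToVecIdx[0]
	require.False(t, ok)
	// pk 30 is the first valid row and maps to dense index 0
	require.EqualValues(t, 0, data.PkToVecIdx[30])
}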
type IndexConfig struct {
Name string
IndexType string
MetricType entity.MetricType
Params map[string]string
}
func GetIndexesForVectorType(fieldType entity.FieldType) []IndexConfig {
switch fieldType {
case entity.FieldTypeFloatVector:
return []IndexConfig{
{"FLAT", "FLAT", entity.L2, nil},
{"IVF_FLAT", "IVF_FLAT", entity.L2, map[string]string{"nlist": "128"}},
{"IVF_SQ8", "IVF_SQ8", entity.L2, map[string]string{"nlist": "128"}},
{"IVF_PQ", "IVF_PQ", entity.L2, map[string]string{"nlist": "128", "m": "8", "nbits": "8"}},
{"HNSW", "HNSW", entity.L2, map[string]string{"M": "16", "efConstruction": "200"}},
{"SCANN", "SCANN", entity.L2, map[string]string{"nlist": "128", "with_raw_data": "true"}},
// {"DISKANN", "DISKANN", entity.L2, nil}, // Skip DISKANN for now
}
case entity.FieldTypeBinaryVector:
return []IndexConfig{
{"BIN_FLAT", "BIN_FLAT", entity.JACCARD, nil},
{"BIN_IVF_FLAT", "BIN_IVF_FLAT", entity.JACCARD, map[string]string{"nlist": "128"}},
}
case entity.FieldTypeFloat16Vector, entity.FieldTypeBFloat16Vector:
return []IndexConfig{
{"FLAT", "FLAT", entity.L2, nil},
{"IVF_FLAT", "IVF_FLAT", entity.L2, map[string]string{"nlist": "128"}},
{"IVF_SQ8", "IVF_SQ8", entity.L2, map[string]string{"nlist": "128"}},
{"HNSW", "HNSW", entity.L2, map[string]string{"M": "16", "efConstruction": "200"}},
}
case entity.FieldTypeInt8Vector:
return []IndexConfig{
{"HNSW", "HNSW", entity.COSINE, map[string]string{"M": "16", "efConstruction": "200"}},
}
case entity.FieldTypeSparseVector:
return []IndexConfig{
{"SPARSE_INVERTED_INDEX", "SPARSE_INVERTED_INDEX", entity.IP, map[string]string{"drop_ratio_build": "0.1"}},
{"SPARSE_WAND", "SPARSE_WAND", entity.IP, map[string]string{"drop_ratio_build": "0.1"}},
}
default:
return []IndexConfig{
{"FLAT", "FLAT", entity.L2, nil},
}
}
}
func CreateIndexFromConfig(fieldName string, cfg IndexConfig) index.Index {
params := map[string]string{
index.MetricTypeKey: string(cfg.MetricType),
index.IndexTypeKey: cfg.IndexType,
}
for k, v := range cfg.Params {
params[k] = v
}
return index.NewGenericIndex(fieldName, params)
}
func CreateNullableVectorIndex(vt NullableVectorType) index.Index {
return CreateNullableVectorIndexWithFieldName(vt, "vector")
}
func CreateNullableVectorIndexWithFieldName(vt NullableVectorType, fieldName string) index.Index {
indexes := GetIndexesForVectorType(vt.FieldType)
if len(indexes) > 0 {
return CreateIndexFromConfig(fieldName, indexes[0])
}
return index.NewGenericIndex(fieldName, map[string]string{
index.MetricTypeKey: string(entity.L2),
index.IndexTypeKey: "FLAT",
})
}
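// sketchIndexFromConfig is an illustrative sketch only (not used by the tests):
// it shows how an IndexConfig is turned into a generic index whose params merge
// the metric/index type with the type-specific parameters.
func sketchIndexFromConfig() index.Index {
	for _, cfg := range GetIndexesForVectorType(entity.FieldTypeFloatVector) {
		if cfg.Name == "HNSW" {
			// yields params: metric_type=L2, index_type=HNSW, M=16, efConstruction=200
			return CreateIndexFromConfig("vector", cfg)
		}
	}
	// fall back to the first (FLAT) config
	return CreateIndexFromConfig("vector", GetIndexesForVectorType(entity.FieldTypeFloatVector)[0])
}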
// VerifyNullableVectorData checks the default "vector" field; it is a thin
// wrapper over VerifyNullableVectorDataWithFieldName below.
func VerifyNullableVectorData(t *testing.T, vt NullableVectorType, queryResult client.ResultSet, pkToVecIdx map[int64]int, originalVectors interface{}, context string) {
	VerifyNullableVectorDataWithFieldName(t, vt, queryResult, pkToVecIdx, originalVectors, "vector", context)
}
func VerifyNullableVectorDataWithFieldName(t *testing.T, vt NullableVectorType, queryResult client.ResultSet, pkToVecIdx map[int64]int, originalVectors interface{}, fieldName string, context string) {
pkCol := queryResult.GetColumn(common.DefaultInt64FieldName).(*column.ColumnInt64)
vecCol := queryResult.GetColumn(fieldName)
for i := 0; i < queryResult.ResultCount; i++ {
pk, _ := pkCol.GetAsInt64(i)
isNull, _ := vecCol.IsNull(i)
if origIdx, ok := pkToVecIdx[pk]; ok {
require.False(t, isNull, "%s: vector should not be null for pk %d", context, pk)
vecData, _ := vecCol.Get(i)
switch vt.FieldType {
case entity.FieldTypeFloatVector:
vectors := originalVectors.([][]float32)
queriedVec := []float32(vecData.(entity.FloatVector))
require.EqualValues(t, common.DefaultDim, len(queriedVec), "%s: vector dimension should match for pk %d", context, pk)
origVec := vectors[origIdx]
for j := range origVec {
require.InDelta(t, origVec[j], queriedVec[j], 1e-6, "%s: vector element %d should match for pk %d", context, j, pk)
}
case entity.FieldTypeInt8Vector:
vectors := originalVectors.([][]int8)
queriedVec := []int8(vecData.(entity.Int8Vector))
require.EqualValues(t, common.DefaultDim, len(queriedVec), "%s: vector dimension should match for pk %d", context, pk)
origVec := vectors[origIdx]
for j := range origVec {
require.EqualValues(t, origVec[j], queriedVec[j], "%s: vector element %d should match for pk %d", context, j, pk)
}
case entity.FieldTypeBinaryVector:
vectors := originalVectors.([][]byte)
queriedVec := []byte(vecData.(entity.BinaryVector))
byteDim := common.DefaultDim / 8
require.EqualValues(t, byteDim, len(queriedVec), "%s: vector byte dimension should match for pk %d", context, pk)
origVec := vectors[origIdx]
for j := range origVec {
require.EqualValues(t, origVec[j], queriedVec[j], "%s: vector byte %d should match for pk %d", context, j, pk)
}
case entity.FieldTypeFloat16Vector:
queriedVec := []byte(vecData.(entity.Float16Vector))
byteDim := common.DefaultDim * 2
require.EqualValues(t, byteDim, len(queriedVec), "%s: vector byte dimension should match for pk %d", context, pk)
case entity.FieldTypeBFloat16Vector:
queriedVec := []byte(vecData.(entity.BFloat16Vector))
byteDim := common.DefaultDim * 2
require.EqualValues(t, byteDim, len(queriedVec), "%s: vector byte dimension should match for pk %d", context, pk)
case entity.FieldTypeSparseVector:
vectors := originalVectors.([]entity.SparseEmbedding)
queriedVec := vecData.(entity.SparseEmbedding)
origVec := vectors[origIdx]
require.EqualValues(t, origVec.Len(), queriedVec.Len(), "%s: sparse vector length should match for pk %d", context, pk)
for j := 0; j < origVec.Len(); j++ {
origPos, origVal, _ := origVec.Get(j)
queriedPos, queriedVal, _ := queriedVec.Get(j)
require.EqualValues(t, origPos, queriedPos, "%s: sparse vector position %d should match for pk %d", context, j, pk)
require.InDelta(t, origVal, queriedVal, 1e-6, "%s: sparse vector value %d should match for pk %d", context, j, pk)
}
}
} else {
require.True(t, isNull, "%s: vector should be null for pk %d", context, pk)
vecData, _ := vecCol.Get(i)
require.Nil(t, vecData, "%s: null vector data should be nil for pk %d", context, pk)
}
}
}
// create collection with nullable fields and insert with column / nullableColumn
func TestNullableDefault(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
// create fields: pk + floatVec + all nullable scalar fields
pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
vecField := entity.NewField().WithName(common.DefaultFloatVecFieldName).WithDataType(entity.FieldTypeFloatVector).WithDim(common.DefaultDim)
schema := entity.NewSchema().WithName(common.GenRandomString("nullable_default_value", 5)).WithField(pkField).WithField(vecField)
// create collection with all supported nullable fields
expNullableFields := make([]string, 0)
for _, fieldType := range hp.GetAllNullableFieldType() {
nullableField := entity.NewField().WithName(common.GenRandomString("null", 5)).WithDataType(fieldType).WithNullable(true)
if fieldType == entity.FieldTypeVarChar {
nullableField.WithMaxLength(common.TestMaxLen)
}
if fieldType == entity.FieldTypeArray {
nullableField.WithElementType(entity.FieldTypeInt64).WithMaxCapacity(common.TestCapacity)
}
schema.WithField(nullableField)
expNullableFields = append(expNullableFields, nullableField.Name)
}
// create collection
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(schema.CollectionName, schema).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
// describe collection and check nullable fields
descCollection, err := mc.DescribeCollection(ctx, client.NewDescribeCollectionOption(schema.CollectionName))
common.CheckErr(t, err, true)
common.CheckFieldsNullable(t, expNullableFields, descCollection.Schema)
prepare := hp.CollPrepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
// insert data with default column
defColumnOpt := hp.TNewColumnOptions()
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), defColumnOpt)
// query with null expr
for _, nullField := range expNullableFields {
countRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(fmt.Sprintf("%s is null", nullField)).WithOutputFields(common.QueryCountFieldName))
common.CheckErr(t, err, true)
count, _ := countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, 0, count)
}
// insert data with nullable column
validData := make([]bool, common.DefaultNb)
for i := 0; i < common.DefaultNb; i++ {
validData[i] = i%2 == 1
}
columnOpt := hp.TNewColumnOptions()
for _, name := range expNullableFields {
columnOpt = columnOpt.WithColumnOption(name, hp.TNewDataOption().TWithValidData(validData))
}
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), columnOpt)
hp.CollPrepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
// query with null expr
for _, nullField := range expNullableFields {
countRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(fmt.Sprintf("%s is null", nullField)).WithOutputFields(common.QueryCountFieldName))
common.CheckErr(t, err, true)
count, _ := countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, common.DefaultNb/2, count)
}
}
// create collection with default value and insert with column / nullableColumn
func TestDefaultValueDefault(t *testing.T) {
t.Skip("set defaultValue and insert with default column gets unexpected error, waiting for fix")
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
// create fields: pk + floatVec + default value scalar fields
pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
vecField := entity.NewField().WithName(common.DefaultFloatVecFieldName).WithDataType(entity.FieldTypeFloatVector).WithDim(common.DefaultDim)
schema := entity.NewSchema().WithName(common.GenRandomString("default_value", 5)).WithField(pkField).WithField(vecField)
// create fields with default values for all supported scalar types
defaultBoolField := entity.NewField().WithName(common.GenRandomString("bool", 3)).WithDataType(entity.FieldTypeBool).WithDefaultValueBool(true)
defaultInt8Field := entity.NewField().WithName(common.GenRandomString("int8", 3)).WithDataType(entity.FieldTypeInt8).WithDefaultValueInt(-1)
defaultInt16Field := entity.NewField().WithName(common.GenRandomString("int16", 3)).WithDataType(entity.FieldTypeInt16).WithDefaultValueInt(4)
defaultInt32Field := entity.NewField().WithName(common.GenRandomString("int32", 3)).WithDataType(entity.FieldTypeInt32).WithDefaultValueInt(2000)
defaultInt64Field := entity.NewField().WithName(common.GenRandomString("int64", 3)).WithDataType(entity.FieldTypeInt64).WithDefaultValueLong(10000)
defaultFloatField := entity.NewField().WithName(common.GenRandomString("float", 3)).WithDataType(entity.FieldTypeFloat).WithDefaultValueFloat(-1.0)
defaultDoubleField := entity.NewField().WithName(common.GenRandomString("double", 3)).WithDataType(entity.FieldTypeDouble).WithDefaultValueDouble(math.MaxFloat64)
defaultVarCharField := entity.NewField().WithName(common.GenRandomString("varchar", 3)).WithDataType(entity.FieldTypeVarChar).WithDefaultValueString("default").WithMaxLength(common.TestMaxLen)
schema.WithField(defaultBoolField).WithField(defaultInt8Field).WithField(defaultInt16Field).WithField(defaultInt32Field).WithField(defaultInt64Field).WithField(defaultFloatField).WithField(defaultDoubleField).WithField(defaultVarCharField)
// create collection
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(schema.CollectionName, schema))
common.CheckErr(t, err, true)
coll, _ := mc.DescribeCollection(ctx, client.NewDescribeCollectionOption(schema.CollectionName))
common.CheckFieldsDefaultValue(t, map[string]interface{}{
defaultBoolField.Name: true,
defaultInt8Field.Name: int8(-1),
defaultInt16Field.Name: int16(4),
defaultInt32Field.Name: int32(2000),
defaultInt64Field.Name: int64(10000),
defaultFloatField.Name: float32(-1.0),
defaultDoubleField.Name: math.MaxFloat64,
defaultVarCharField.Name: "default",
}, coll.Schema)
prepare := hp.CollPrepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
// insert data with default column
defColumnOpt := hp.TNewColumnOptions()
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), defColumnOpt)
// query with default-value expr: all inserted rows carry generated values, so none used the default yet
countRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(fmt.Sprintf("%s == -1", defaultInt8Field.Name)).WithOutputFields(common.QueryCountFieldName))
common.CheckErr(t, err, true)
count, _ := countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, 0, count)
// insert data
validData := make([]bool, common.DefaultNb)
for i := 0; i < common.DefaultNb; i++ {
validData[i] = i%2 == 0
}
columnOpt := hp.TNewColumnOptions()
for _, name := range []string{
defaultBoolField.Name, defaultInt8Field.Name, defaultInt16Field.Name, defaultInt32Field.Name,
defaultInt64Field.Name, defaultFloatField.Name, defaultDoubleField.Name, defaultVarCharField.Name,
} {
columnOpt = columnOpt.WithColumnOption(name, hp.TNewDataOption().TWithValidData(validData))
}
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), columnOpt)
// query with exprs to check default values
type exprCount struct {
expr string
count int64
}
exprCounts := []exprCount{
{expr: fmt.Sprintf("%s == %t", defaultBoolField.Name, true), count: common.DefaultNb * 5 / 4},
{expr: fmt.Sprintf("%s == %d", defaultInt8Field.Name, -1), count: common.DefaultNb/2 + 10}, // int8 [-128, 127]
{expr: fmt.Sprintf("%s == %d", defaultInt16Field.Name, 4), count: common.DefaultNb/2 + 2},
{expr: fmt.Sprintf("%s == %d", defaultInt32Field.Name, 2000), count: common.DefaultNb / 2},
{expr: fmt.Sprintf("%s == %d", defaultInt64Field.Name, 10000), count: common.DefaultNb / 2},
{expr: fmt.Sprintf("%s == %f", defaultFloatField.Name, -1.0), count: common.DefaultNb / 2},
{expr: fmt.Sprintf("%s == %f", defaultDoubleField.Name, math.MaxFloat64), count: common.DefaultNb / 2},
{expr: fmt.Sprintf("%s == '%s'", defaultVarCharField.Name, "default"), count: common.DefaultNb / 2},
}
for _, exprCount := range exprCounts {
countRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(exprCount.expr).WithOutputFields(common.QueryCountFieldName))
common.CheckErr(t, err, true)
count, _ := countRes.Fields[0].GetAsInt64(0)
require.Equal(t, exprCount.count, count)
}
}
func TestNullableInvalid(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
pkField := entity.NewField().WithName(common.GenRandomString("pk", 3)).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
vecField := entity.NewField().WithName(common.GenRandomString("vec", 3)).WithDataType(entity.FieldTypeFloatVector).WithDim(common.DefaultDim)
// pk field does not support nullable
pkFieldNull := entity.NewField().WithName(common.GenRandomString("pk", 3)).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true).WithNullable(true)
schema := entity.NewSchema().WithName(common.GenRandomString("nullable_invalid_field", 5)).WithField(pkFieldNull).WithField(vecField)
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(schema.CollectionName, schema))
common.CheckErr(t, err, false, "primary field not support null")
supportedNullableVectorTypes := []entity.FieldType{entity.FieldTypeFloatVector, entity.FieldTypeBinaryVector, entity.FieldTypeFloat16Vector, entity.FieldTypeBFloat16Vector, entity.FieldTypeSparseVector, entity.FieldTypeInt8Vector}
for _, fieldType := range supportedNullableVectorTypes {
nullableVectorField := entity.NewField().WithName(common.GenRandomString("null", 3)).WithDataType(fieldType).WithNullable(true)
if fieldType != entity.FieldTypeSparseVector {
nullableVectorField.WithDim(128)
}
schema := entity.NewSchema().WithName(common.GenRandomString("nullable_vector", 5)).WithField(pkField).WithField(nullableVectorField)
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(schema.CollectionName, schema))
common.CheckErr(t, err, true)
mc.DropCollection(ctx, client.NewDropCollectionOption(schema.CollectionName))
}
// partition-key field does not support nullable
partitionField := entity.NewField().WithName(common.GenRandomString("partition", 3)).WithDataType(entity.FieldTypeInt64).WithIsPartitionKey(true).WithNullable(true)
schema = entity.NewSchema().WithName(common.GenRandomString("nullable_invalid_field", 5)).WithField(pkField).WithField(vecField).WithField(partitionField)
err = mc.CreateCollection(ctx, client.NewCreateCollectionOption(schema.CollectionName, schema))
common.CheckErr(t, err, false, "partition key field not support nullable")
}
func TestDefaultValueInvalid(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
pkField := entity.NewField().WithName(common.GenRandomString("pk", 3)).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
vecField := entity.NewField().WithName(common.GenRandomString("vec", 3)).WithDataType(entity.FieldTypeFloatVector).WithDim(common.DefaultDim)
// pk field does not support default value
pkFieldNull := entity.NewField().WithName(common.GenRandomString("pk", 3)).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true).WithDefaultValueLong(1)
schema := entity.NewSchema().WithName(common.GenRandomString("def_invalid_field", 5)).WithField(pkFieldNull).WithField(vecField)
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(schema.CollectionName, schema))
common.CheckErr(t, err, false, "primary field not support default_value")
// vector types do not support default value
notSupportedDefaultValueDataTypes := []entity.FieldType{entity.FieldTypeFloatVector, entity.FieldTypeBinaryVector, entity.FieldTypeFloat16Vector, entity.FieldTypeBFloat16Vector, entity.FieldTypeSparseVector, entity.FieldTypeInt8Vector}
for _, fieldType := range notSupportedDefaultValueDataTypes {
nullableVectorField := entity.NewField().WithName(common.GenRandomString("def", 3)).WithDataType(fieldType).WithDefaultValueFloat(2.0)
if fieldType != entity.FieldTypeSparseVector {
nullableVectorField.WithDim(128)
}
schema := entity.NewSchema().WithName(common.GenRandomString("def_invalid_field", 5)).WithField(pkField).WithField(nullableVectorField)
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(schema.CollectionName, schema))
common.CheckErr(t, err, false, "type not support default_value")
}
// json and array types do not support default value
notSupportedDefaultValueDataTypes = []entity.FieldType{entity.FieldTypeJSON, entity.FieldTypeArray}
for _, fieldType := range notSupportedDefaultValueDataTypes {
nullableVectorField := entity.NewField().WithName(common.GenRandomString("def", 3)).WithDataType(fieldType).WithElementType(entity.FieldTypeFloat).WithMaxCapacity(100).WithDefaultValueFloat(2.0)
schema := entity.NewSchema().WithName(common.GenRandomString("def_invalid_field", 5)).WithField(pkField).WithField(vecField).WithField(nullableVectorField)
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(schema.CollectionName, schema))
common.CheckErr(t, err, false, "type not support default_value")
}
}
func TestDefaultValueInvalidValue(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
pkField := entity.NewField().WithName(common.GenRandomString("pk", 3)).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
vecField := entity.NewField().WithName(common.GenRandomString("vec", 3)).WithDataType(entity.FieldTypeFloatVector).WithDim(common.DefaultDim)
// define the test cases
testCases := []struct {
name string
fieldName string
dataType entity.FieldType
setupField func() *entity.Field
expectedErr string
}{
{
name: "varchar field default_value_length > max_length",
fieldName: common.DefaultVarcharFieldName,
dataType: entity.FieldTypeVarChar,
setupField: func() *entity.Field {
return entity.NewField().WithName(common.DefaultVarcharFieldName).
WithDataType(entity.FieldTypeVarChar).
WithDefaultValueString("defaultaaaaaaaaa").
WithMaxLength(2)
},
expectedErr: "the length (16) of string exceeds max length (2)",
},
{
name: "varchar field with int default_value",
fieldName: common.DefaultVarcharFieldName,
dataType: entity.FieldTypeVarChar,
setupField: func() *entity.Field {
return entity.NewField().WithName(common.DefaultVarcharFieldName).
WithDataType(entity.FieldTypeVarChar).
WithDefaultValueInt(2).
WithMaxLength(100)
},
expectedErr: fmt.Sprintf("type (VarChar) of field (%s) is not equal to the type(DataType_Int) of default_value", common.DefaultVarcharFieldName),
},
{
name: "int32 field with int64 default_value",
fieldName: common.DefaultInt32FieldName,
dataType: entity.FieldTypeInt32,
setupField: func() *entity.Field {
return entity.NewField().WithName(common.DefaultInt32FieldName).
WithDataType(entity.FieldTypeInt32).
WithDefaultValueLong(2)
},
expectedErr: fmt.Sprintf("type (Int32) of field (%s) is not equal to the type(DataType_Int64) of default_value", common.DefaultInt32FieldName),
},
{
name: "int64 field with int default_value",
fieldName: common.DefaultInt64FieldName,
dataType: entity.FieldTypeInt64,
setupField: func() *entity.Field {
return entity.NewField().WithName(common.DefaultInt64FieldName).
WithDataType(entity.FieldTypeInt64).
WithDefaultValueInt(2)
},
expectedErr: fmt.Sprintf("type (Int64) of field (%s) is not equal to the type(DataType_Int) of default_value", common.DefaultInt64FieldName),
},
{
name: "float field with double default_value",
fieldName: common.DefaultFloatFieldName,
dataType: entity.FieldTypeFloat,
setupField: func() *entity.Field {
return entity.NewField().WithName(common.DefaultFloatFieldName).
WithDataType(entity.FieldTypeFloat).
WithDefaultValueDouble(2.6)
},
expectedErr: fmt.Sprintf("type (Float) of field (%s) is not equal to the type(DataType_Double) of default_value", common.DefaultFloatFieldName),
},
{
name: "double field with varchar default_value",
fieldName: common.DefaultDoubleFieldName,
dataType: entity.FieldTypeDouble,
setupField: func() *entity.Field {
return entity.NewField().WithName(common.DefaultDoubleFieldName).
WithDataType(entity.FieldTypeDouble).
WithDefaultValueString("2.6")
},
expectedErr: fmt.Sprintf("type (Double) of field (%s) is not equal to the type(DataType_VarChar) of default_value", common.DefaultDoubleFieldName),
},
}
// run the test cases
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
defField := tc.setupField()
schema := entity.NewSchema().WithName("def_invalid_field").WithField(pkField).WithField(vecField).WithField(defField)
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(schema.CollectionName, schema))
common.CheckErr(t, err, false, tc.expectedErr)
})
}
}
func TestDefaultValueOutOfRange(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
pkField := entity.NewField().WithName(common.GenRandomString("pk", 3)).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
vecField := entity.NewField().WithName(common.GenRandomString("vec", 3)).WithDataType(entity.FieldTypeFloatVector).WithDim(common.DefaultDim)
testCases := []struct {
name string
fieldName string
dataType entity.FieldType
setupField func() *entity.Field
expectedErr string
}{
{
name: "int8 field with out_of_range default_value",
fieldName: common.DefaultInt8FieldName,
dataType: entity.FieldTypeInt8,
setupField: func() *entity.Field {
return entity.NewField().WithName(common.DefaultInt8FieldName).WithDataType(entity.FieldTypeInt8).WithDefaultValueInt(128)
},
expectedErr: "[128 out of range -128 <= value <= 127]",
},
{
name: "int16 field with out_of_range default_value",
fieldName: common.DefaultInt16FieldName,
dataType: entity.FieldTypeInt16,
setupField: func() *entity.Field {
return entity.NewField().WithName(common.DefaultInt16FieldName).WithDataType(entity.FieldTypeInt16).WithDefaultValueInt(-32769)
},
expectedErr: "[-32769 out of range -32768 <= value <= 32767]",
},
}
// run the test cases
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
defField := tc.setupField()
schema := entity.NewSchema().WithName(common.GenRandomString("def", 5)).WithField(pkField).WithField(vecField).WithField(defField)
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(schema.CollectionName, schema))
common.CheckErr(t, err, false, tc.expectedErr)
})
}
}
// test default value "" and insert ""
func TestDefaultValueVarcharEmpty(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
fieldsOpt := hp.TNewFieldOptions().WithFieldOption(common.DefaultVarcharFieldName, hp.TNewFieldsOption().TWithDefaultValue(""))
prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64VarcharSparseVec), fieldsOpt, hp.TNewSchemaOption(), hp.TWithConsistencyLevel(entity.ClStrong))
coll, _ := mc.DescribeCollection(ctx, client.NewDescribeCollectionOption(schema.CollectionName))
common.CheckFieldsDefaultValue(t, map[string]interface{}{
common.DefaultVarcharFieldName: "",
}, coll.Schema)
// insert data
validData := make([]bool, common.DefaultNb)
for i := 0; i < common.DefaultNb; i++ {
validData[i] = i%2 == 0
}
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewColumnOptions().WithColumnOption(common.DefaultVarcharFieldName, hp.TNewDataOption().TWithValidData(validData)))
hp.CollPrepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema).TWithFieldIndex(map[string]index.Index{common.DefaultVarcharFieldName: index.NewAutoIndex(entity.COSINE)}))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
expr := fmt.Sprintf("%s == ''", common.DefaultVarcharFieldName)
countRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(expr).WithOutputFields(common.QueryCountFieldName))
common.CheckErr(t, err, true)
count, _ := countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, common.DefaultNb/2, count)
// insert varchar data: ""
varcharValues := make([]string, common.DefaultNb/2)
for i := 0; i < common.DefaultNb/2; i++ {
varcharValues[i] = ""
}
columnOpt := hp.TNewColumnOptions().WithColumnOption(common.DefaultInt64FieldName, hp.TNewDataOption().TWithStart(common.DefaultNb)).
WithColumnOption(common.DefaultVarcharFieldName, hp.TNewDataOption().TWithValidData(validData).TWithTextData(varcharValues))
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), columnOpt)
countRes, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(expr).WithOutputFields(common.QueryCountFieldName))
common.CheckErr(t, err, true)
count, _ = countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, common.DefaultNb*3/2, count)
}
// test inserting a nullable column into a normal (non-nullable) collection
func TestNullableDefaultInsertInvalid(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
// create normal default collection -> insert null values
fieldsOpt := hp.TNewFieldOptions().WithFieldOption(common.DefaultVarcharFieldName, hp.TNewFieldsOption())
_, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64VarcharSparseVec), fieldsOpt, hp.TNewSchemaOption())
validData := make([]bool, common.DefaultNb)
for i := 0; i < common.DefaultNb; i++ {
validData[i] = i%2 == 0
}
pkColumn := hp.GenColumnData(common.DefaultNb, entity.FieldTypeInt64, *hp.TNewDataOption())
vecColumn := hp.GenColumnData(common.DefaultNb, entity.FieldTypeSparseVector, *hp.TNewDataOption().TWithSparseMaxLen(common.DefaultDim))
varcharColumn := hp.GenColumnData(common.DefaultNb, entity.FieldTypeVarChar, *hp.TNewDataOption().TWithValidData(validData))
_, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName).WithColumns(pkColumn).WithColumns(vecColumn).WithColumns(varcharColumn))
common.CheckErr(t, err, false, "the length of valid_data of field(varchar) is wrong")
}
// test insert with part/all/not null -> query check
func TestNullableQuery(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
fieldsOpt := hp.TNewFieldOptions().WithFieldOption(common.DefaultVarcharFieldName, hp.TNewFieldsOption().TWithNullable(true))
prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64VarcharSparseVec), fieldsOpt, hp.TNewSchemaOption(), hp.TWithConsistencyLevel(entity.ClStrong))
coll, _ := mc.DescribeCollection(ctx, client.NewDescribeCollectionOption(schema.CollectionName))
common.CheckFieldsNullable(t, []string{common.DefaultVarcharFieldName}, coll.Schema)
// insert data with part null, all null, all valid
partNullValidData := make([]bool, common.DefaultNb)
allNullValidData := make([]bool, common.DefaultNb)
allValidData := make([]bool, common.DefaultNb)
for i := 0; i < common.DefaultNb; i++ {
partNullValidData[i] = i%3 == 0
allNullValidData[i] = false
allValidData[i] = true
}
// [0, nb) -> 2*nb/3 null
// [nb, 2*nb) -> all null
// [2*nb, 3*nb) -> all valid
// [3*nb, 4*nb) -> all valid
for i, data := range [][]bool{partNullValidData, allNullValidData, allValidData} {
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewColumnOptions().
WithColumnOption(common.DefaultInt64FieldName, hp.TNewDataOption().TWithStart(common.DefaultNb*i)).
WithColumnOption(common.DefaultVarcharFieldName, hp.TNewDataOption().TWithValidData(data).TWithStart(common.DefaultNb*i)))
}
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewColumnOptions().
WithColumnOption(common.DefaultInt64FieldName, hp.TNewDataOption().TWithStart(common.DefaultNb*3)))
hp.CollPrepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema).TWithFieldIndex(map[string]index.Index{common.DefaultVarcharFieldName: index.NewAutoIndex(entity.COSINE)}))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
exprCounts := []struct {
expr string
count int64
}{
{expr: fmt.Sprintf("%s is null and (%d <= %s < %d)", common.DefaultVarcharFieldName, 0, common.DefaultInt64FieldName, common.DefaultNb), count: common.DefaultNb * 2 / 3},
{expr: fmt.Sprintf("%s is null and %d <= %s < %d", common.DefaultVarcharFieldName, common.DefaultNb, common.DefaultInt64FieldName, common.DefaultNb*2), count: common.DefaultNb},
{expr: fmt.Sprintf("%s is null and %d <= %s < %d", common.DefaultVarcharFieldName, common.DefaultNb*2, common.DefaultInt64FieldName, common.DefaultNb*3), count: 0},
{expr: fmt.Sprintf("%s is not null and %d <= %s < %d", common.DefaultVarcharFieldName, common.DefaultNb*3, common.DefaultInt64FieldName, common.DefaultNb*4), count: common.DefaultNb},
}
for _, exprCount := range exprCounts {
log.Info("exprCount", zap.String("expr", exprCount.expr), zap.Int64("count", exprCount.count))
countRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(exprCount.expr).WithOutputFields(common.QueryCountFieldName))
common.CheckErr(t, err, true)
count, _ := countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, exprCount.count, count)
}
}
// test insert with part/all/not default value -> query check
func TestDefaultValueQuery(t *testing.T) {
t.Skip("set defaultValue and insert with default column gets unexpected error, waiting for fix")
for _, nullable := range [2]bool{false, true} {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
fieldsOpt := hp.TNewFieldOptions().WithFieldOption(common.DefaultVarcharFieldName, hp.TNewFieldsOption().TWithDefaultValue("test").TWithNullable(nullable))
prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64VarcharSparseVec), fieldsOpt, hp.TNewSchemaOption(), hp.TWithConsistencyLevel(entity.ClStrong))
coll, _ := mc.DescribeCollection(ctx, client.NewDescribeCollectionOption(schema.CollectionName))
common.CheckFieldsDefaultValue(t, map[string]interface{}{
common.DefaultVarcharFieldName: "test",
}, coll.Schema)
// insert data with part null, all null, all valid
partNullValidData := make([]bool, common.DefaultNb)
allNullValidData := make([]bool, common.DefaultNb)
allValidData := make([]bool, common.DefaultNb)
for i := 0; i < common.DefaultNb; i++ {
partNullValidData[i] = i%2 == 0
allNullValidData[i] = false
allValidData[i] = true
}
// [0, nb) -> nb/2 default value
// [nb, 2*nb) -> all default value
// [2*nb, 3*nb) -> all valid
// [3*nb, 4*nb) -> all valid
for i, data := range [][]bool{partNullValidData, allNullValidData, allValidData} {
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewColumnOptions().
WithColumnOption(common.DefaultInt64FieldName, hp.TNewDataOption().TWithStart(common.DefaultNb*i)).
WithColumnOption(common.DefaultVarcharFieldName, hp.TNewDataOption().TWithValidData(data)))
}
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewColumnOptions().
WithColumnOption(common.DefaultInt64FieldName, hp.TNewDataOption().TWithStart(common.DefaultNb*3)))
hp.CollPrepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema).TWithFieldIndex(map[string]index.Index{common.DefaultVarcharFieldName: index.NewAutoIndex(entity.COSINE)}))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
exprCounts := []struct {
expr string
count int64
}{
{expr: fmt.Sprintf("%s == 'test' and (%d <= %s < %d)", common.DefaultVarcharFieldName, 0, common.DefaultInt64FieldName, common.DefaultNb), count: common.DefaultNb / 2},
{expr: fmt.Sprintf("%s == 'test' and %d <= %s < %d", common.DefaultVarcharFieldName, common.DefaultNb, common.DefaultInt64FieldName, common.DefaultNb*2), count: common.DefaultNb},
{expr: fmt.Sprintf("%s == 'test' and %d <= %s < %d", common.DefaultVarcharFieldName, common.DefaultNb*2, common.DefaultInt64FieldName, common.DefaultNb*3), count: 0},
{expr: fmt.Sprintf("%s == 'test' and %d <= %s < %d", common.DefaultVarcharFieldName, common.DefaultNb*3, common.DefaultInt64FieldName, common.DefaultNb*4), count: 0},
{expr: fmt.Sprintf("%s == 'test'", common.DefaultVarcharFieldName), count: common.DefaultNb * 3 / 2},
}
for _, exprCount := range exprCounts {
log.Info("exprCount", zap.String("expr", exprCount.expr), zap.Int64("count", exprCount.count))
countRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(exprCount.expr).WithOutputFields(common.QueryCountFieldName))
common.CheckErr(t, err, true)
count, _ := countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, exprCount.count, count)
}
}
}
// clustering-key nullable
func TestNullableClusteringKey(t *testing.T) {
// test clustering key nullable
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
fieldsOpt := hp.TNewFieldOptions().WithFieldOption(common.DefaultVarcharFieldName, hp.TNewFieldsOption().TWithNullable(true).TWithIsClusteringKey(true))
prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64VarcharSparseVec), fieldsOpt, hp.TNewSchemaOption(), hp.TWithConsistencyLevel(entity.ClStrong))
// insert with valid data
validData := make([]bool, common.DefaultNb)
for i := 0; i < common.DefaultNb; i++ {
validData[i] = i%2 == 0
}
for i := 0; i < 5; i++ {
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewColumnOptions().
WithColumnOption(common.DefaultVarcharFieldName, hp.TNewDataOption().TWithValidData(validData)).
WithColumnOption(common.DefaultInt64FieldName, hp.TNewDataOption().TWithStart(common.DefaultNb*i)))
}
prepare.FlushData(ctx, t, mc, schema.CollectionName)
prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema).TWithFieldIndex(map[string]index.Index{common.DefaultVarcharFieldName: index.NewAutoIndex(entity.COSINE)}))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
expr := fmt.Sprintf("%s == '1'", common.DefaultVarcharFieldName)
countRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(expr).WithOutputFields(common.QueryCountFieldName))
common.CheckErr(t, err, true)
count, _ := countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, 5, count)
}
// partition-key field with default value
func TestDefaultValuePartitionKey(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
fieldsOpt := hp.TNewFieldOptions().WithFieldOption(common.DefaultVarcharFieldName, hp.TNewFieldsOption().TWithDefaultValue("parkey").TWithIsPartitionKey(true))
prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64VarcharSparseVec), fieldsOpt, hp.TNewSchemaOption(),
hp.TWithConsistencyLevel(entity.ClStrong), hp.TWithNullablePartitions(3))
// insert with valid data
validData := make([]bool, common.DefaultNb)
for i := 0; i < common.DefaultNb; i++ {
validData[i] = i%2 == 0
}
for i := 0; i < 3; i++ {
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewColumnOptions().
WithColumnOption(common.DefaultVarcharFieldName, hp.TNewDataOption().TWithValidData(validData)).
WithColumnOption(common.DefaultInt64FieldName, hp.TNewDataOption().TWithStart(common.DefaultNb*i)))
}
prepare.FlushData(ctx, t, mc, schema.CollectionName)
prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema).TWithFieldIndex(map[string]index.Index{common.DefaultVarcharFieldName: index.NewAutoIndex(entity.COSINE)}))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
expr := "varchar like 'parkey%'"
countRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(expr).WithOutputFields(common.QueryCountFieldName))
common.CheckErr(t, err, true)
count, _ := countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, common.DefaultNb*3/2, count)
}
func TestNullableGroupBy(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
fieldsOpt := hp.TNewFieldOptions().WithFieldOption(common.DefaultVarcharFieldName, hp.TNewFieldsOption().TWithNullable(true))
prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64VarcharSparseVec), fieldsOpt, hp.TNewSchemaOption(),
hp.TWithConsistencyLevel(entity.ClStrong))
// insert with valid data
validData := make([]bool, common.DefaultNb)
for i := 0; i < common.DefaultNb; i++ {
if i > 200 {
validData[i] = true
}
}
for i := 0; i < 10; i++ {
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewColumnOptions().
WithColumnOption(common.DefaultVarcharFieldName, hp.TNewDataOption().TWithValidData(validData)).
WithColumnOption(common.DefaultInt64FieldName, hp.TNewDataOption().TWithStart(common.DefaultNb*i)))
}
prepare.FlushData(ctx, t, mc, schema.CollectionName)
prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
// nullable field as group by field
queryVec := hp.GenSearchVectors(2, common.DefaultDim, entity.FieldTypeSparseVector)
searchRes, err := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, queryVec).WithGroupByField(common.DefaultVarcharFieldName))
common.CheckErr(t, err, true)
common.CheckSearchResult(t, searchRes, 2, common.DefaultLimit)
}
func TestNullableSearch(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
fieldsOpt := hp.TNewFieldOptions().WithFieldOption(common.DefaultVarcharFieldName, hp.TNewFieldsOption().TWithNullable(true))
prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64VarcharSparseVec), fieldsOpt, hp.TNewSchemaOption().TWithEnableDynamicField(true),
hp.TWithConsistencyLevel(entity.ClStrong))
validData := make([]bool, common.DefaultNb)
for i := 0; i < common.DefaultNb; i++ {
validData[i] = i%2 == 0
}
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewColumnOptions().
WithColumnOption(common.DefaultVarcharFieldName, hp.TNewDataOption().TWithValidData(validData)))
prepare.FlushData(ctx, t, mc, schema.CollectionName)
prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
// search with nullable expr and output nullable / dynamic field
queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeSparseVector)
expr := fmt.Sprintf("%s is null", common.DefaultVarcharFieldName)
searchRes, err := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, queryVec).WithFilter(expr).WithOutputFields(common.DefaultVarcharFieldName, common.DefaultDynamicFieldName))
common.CheckErr(t, err, true)
common.CheckSearchResult(t, searchRes, common.DefaultNq, common.DefaultLimit)
common.CheckOutputFields(t, []string{common.DefaultVarcharFieldName, common.DefaultDynamicFieldName}, searchRes[0].Fields)
for _, field := range searchRes[0].Fields {
if field.Name() == common.DefaultVarcharFieldName {
for i := 0; i < field.Len(); i++ {
isNull, err := field.IsNull(i)
common.CheckErr(t, err, true)
require.True(t, isNull)
}
}
}
}
func TestDefaultValueSearch(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
fieldsOpt := hp.TNewFieldOptions().WithFieldOption(common.DefaultVarcharFieldName, hp.TNewFieldsOption().TWithDefaultValue("test").TWithNullable(true))
prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64VarcharSparseVec), fieldsOpt, hp.TNewSchemaOption().TWithEnableDynamicField(true),
hp.TWithConsistencyLevel(entity.ClStrong))
validData := make([]bool, common.DefaultNb)
for i := 0; i < common.DefaultNb; i++ {
validData[i] = i%2 == 0
}
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewColumnOptions().
WithColumnOption(common.DefaultVarcharFieldName, hp.TNewDataOption().TWithValidData(validData)))
prepare.FlushData(ctx, t, mc, schema.CollectionName)
prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
// search with default-value expr and output nullable / dynamic field
queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeSparseVector)
expr := fmt.Sprintf("%s == 'test'", common.DefaultVarcharFieldName)
searchRes, err := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, queryVec).WithFilter(expr).WithOutputFields(common.DefaultVarcharFieldName, common.DefaultDynamicFieldName))
common.CheckErr(t, err, true)
common.CheckSearchResult(t, searchRes, common.DefaultNq, common.DefaultLimit)
common.CheckOutputFields(t, []string{common.DefaultVarcharFieldName, common.DefaultDynamicFieldName}, searchRes[0].Fields)
for _, field := range searchRes[0].Fields {
if field.Name() == common.DefaultVarcharFieldName {
for i := 0; i < field.Len(); i++ {
fieldData, _ := field.GetAsString(i)
require.EqualValues(t, "test", fieldData)
}
}
}
}
// test nullable fields with auto scalar indexes
func TestNullableAutoScalarIndex(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
// create fields: pk + floatVec + all nullable scalar fields
pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
vecField := entity.NewField().WithName(common.DefaultFloatVecFieldName).WithDataType(entity.FieldTypeFloatVector).WithDim(common.DefaultDim)
schema := entity.NewSchema().WithName(common.GenRandomString("nullable_default_value", 5)).WithField(pkField).WithField(vecField)
// create collection with all supported nullable fields
expNullableFields := make([]string, 0)
for _, fieldType := range hp.GetAllNullableFieldType() {
nullableField := entity.NewField().WithName(common.GenRandomString("null", 5)).WithDataType(fieldType).WithNullable(true)
if fieldType == entity.FieldTypeVarChar {
nullableField.WithMaxLength(common.TestMaxLen)
}
if fieldType == entity.FieldTypeArray {
nullableField.WithElementType(entity.FieldTypeInt64).WithMaxCapacity(common.TestCapacity)
}
schema.WithField(nullableField)
expNullableFields = append(expNullableFields, nullableField.Name)
}
// create collection
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(schema.CollectionName, schema).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
// describe collection and check nullable fields
descCollection, err := mc.DescribeCollection(ctx, client.NewDescribeCollectionOption(schema.CollectionName))
common.CheckErr(t, err, true)
common.CheckFieldsNullable(t, expNullableFields, descCollection.Schema)
// insert data with nullable column
validData := make([]bool, common.DefaultNb)
for i := 0; i < common.DefaultNb; i++ {
validData[i] = i%2 == 1
}
columnOpt := hp.TNewColumnOptions()
for _, name := range expNullableFields {
columnOpt = columnOpt.WithColumnOption(name, hp.TNewDataOption().TWithValidData(validData))
}
for i := 0; i < 3; i++ {
columnOpt = columnOpt.WithColumnOption(common.DefaultInt64FieldName, hp.TNewDataOption().TWithStart(common.DefaultNb*i))
hp.CollPrepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), columnOpt)
}
prepare := hp.CollPrepare.FlushData(ctx, t, mc, schema.CollectionName)
// create auto scalar index for all nullable fields
indexOpt := hp.TNewIndexParams(schema)
for _, name := range expNullableFields {
indexOpt = indexOpt.TWithFieldIndex(map[string]index.Index{name: index.NewAutoIndex(entity.L2)})
}
prepare.CreateIndex(ctx, t, mc, indexOpt)
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
// query with null expr
for _, nullField := range expNullableFields {
countRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(fmt.Sprintf("%s is null", nullField)).WithOutputFields(common.QueryCountFieldName))
common.CheckErr(t, err, true)
count, _ := countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, common.DefaultNb*3/2, count)
}
}
func TestNullableUpsert(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
fieldsOpt := hp.TNewFieldOptions().WithFieldOption(common.DefaultVarcharFieldName, hp.TNewFieldsOption().TWithNullable(true))
prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64VarcharSparseVec), fieldsOpt, hp.TNewSchemaOption(), hp.TWithConsistencyLevel(entity.ClStrong))
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewColumnOptions())
prepare.FlushData(ctx, t, mc, schema.CollectionName)
prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
// query is null return empty result
expr := fmt.Sprintf("%s is null", common.DefaultVarcharFieldName)
countRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(expr).WithOutputFields(common.QueryCountFieldName))
common.CheckErr(t, err, true)
count, _ := countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, 0, count)
// upsert with part null, all null, all valid
partNullValidData := make([]bool, common.DefaultNb)
allNullValidData := make([]bool, common.DefaultNb)
allValidData := make([]bool, common.DefaultNb)
for i := 0; i < common.DefaultNb; i++ {
partNullValidData[i] = i%2 == 0
allNullValidData[i] = false
allValidData[i] = true
}
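// three upsert passes over the same pks: half null, all null, all valid; each pass should leave the matching "is null" count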
validCount := []int{common.DefaultNb / 2, 0, common.DefaultNb}
expNullCount := []int{common.DefaultNb / 2, common.DefaultNb, 0}
pkColumn := hp.GenColumnData(common.DefaultNb, entity.FieldTypeInt64, *hp.TNewDataOption())
vecColumn := hp.GenColumnData(common.DefaultNb, entity.FieldTypeSparseVector, *hp.TNewDataOption())
for i, validData := range [][]bool{partNullValidData, allNullValidData, allValidData} {
varcharData := make([]string, validCount[i])
for j := 0; j < validCount[i]; j++ {
varcharData[j] = "aaa"
}
nullVarcharColumn, err := column.NewNullableColumnVarChar(common.DefaultVarcharFieldName, varcharData, validData)
common.CheckErr(t, err, true)
upsertRes, err := mc.Upsert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName).WithColumns(pkColumn).WithColumns(vecColumn).WithColumns(nullVarcharColumn))
common.CheckErr(t, err, true)
require.EqualValues(t, common.DefaultNb, upsertRes.UpsertCount)
countRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(expr).WithOutputFields(common.QueryCountFieldName))
common.CheckErr(t, err, true)
count, _ := countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, expNullCount[i], count)
}
}

func TestNullableDelete(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
fieldsOpt := hp.TNewFieldOptions().WithFieldOption(common.DefaultVarcharFieldName, hp.TNewFieldsOption().TWithNullable(true))
prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64VarcharSparseVec), fieldsOpt, hp.TNewSchemaOption(),
hp.TWithConsistencyLevel(entity.ClStrong))
validData := make([]bool, common.DefaultNb)
for i := 0; i < common.DefaultNb; i++ {
validData[i] = i%2 == 0
}
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewColumnOptions().
WithColumnOption(common.DefaultVarcharFieldName, hp.TNewDataOption().TWithValidData(validData)))
prepare.FlushData(ctx, t, mc, schema.CollectionName)
prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
// delete null
expr := fmt.Sprintf("%s is null", common.DefaultVarcharFieldName)
deleteRes, err := mc.Delete(ctx, client.NewDeleteOption(schema.CollectionName).WithExpr(expr))
common.CheckErr(t, err, true)
require.EqualValues(t, common.DefaultNb/2, deleteRes.DeleteCount)
countRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(expr).WithOutputFields(common.QueryCountFieldName))
common.CheckErr(t, err, true)
count, _ := countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, 0, count)
}

// TODO: test rows with nullable and default value
func TestNullableRows(t *testing.T) {
t.Skip("Waiting for rows-inserts to support nullable and defaultValue")
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
fieldsOpt := hp.TNewFieldOptions().WithFieldOption(common.DefaultVarcharFieldName, hp.TNewFieldsOption().TWithNullable(true))
prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64VarcharSparseVec), fieldsOpt, hp.TNewSchemaOption().TWithEnableDynamicField(false),
hp.TWithConsistencyLevel(entity.ClStrong))
prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema))
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
validData := make([]bool, common.DefaultNb)
for i := 0; i < common.DefaultNb; i++ {
validData[i] = i%2 == 0
}
rows := hp.GenInt64VarcharSparseRows(common.DefaultNb, false, false, *hp.TNewDataOption().TWithValidData(validData))
insertRes, err := mc.Insert(ctx, client.NewRowBasedInsertOption(schema.CollectionName, rows...))
common.CheckErr(t, err, true)
require.EqualValues(t, common.DefaultNb, insertRes.InsertCount)
expr := fmt.Sprintf("%s is null", common.DefaultVarcharFieldName)
countRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(expr).WithOutputFields(common.QueryCountFieldName))
common.CheckErr(t, err, true)
count, _ := countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, common.DefaultNb/2, count)
}

func TestNullableVectorAllTypes(t *testing.T) {
vectorTypes := GetVectorTypes()
nullPercents := GetNullPercents()
for _, vt := range vectorTypes {
for _, nullPercent := range nullPercents {
testName := fmt.Sprintf("%s_%d%%null", vt.Name, nullPercent)
t.Run(testName, func(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
// create collection
collName := common.GenRandomString("nullable_vec", 5)
pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
vecField := entity.NewField().WithName("vector").WithDataType(vt.FieldType).WithNullable(true)
if vt.FieldType != entity.FieldTypeSparseVector {
vecField = vecField.WithDim(common.DefaultDim)
}
schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField)
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
nb := 500
validData := make([]bool, nb)
validCount := 0
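// within every run of 100 consecutive rows, the first nullPercent rows are null and the rest are valid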
for i := range nb {
validData[i] = (i % 100) >= nullPercent
if validData[i] {
validCount++
}
}
pkData := make([]int64, nb)
for i := range nb {
pkData[i] = int64(i)
}
pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData)
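// pkToVecIdx maps each valid pk to its index in the packed vector slice; null rows occupy no slot in the column data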
pkToVecIdx := make(map[int64]int)
vecIdx := 0
for i := range nb {
if validData[i] {
pkToVecIdx[int64(i)] = vecIdx
vecIdx++
}
}
var vecColumn column.Column
var searchVec entity.Vector
var originalVectors interface{}
switch vt.FieldType {
case entity.FieldTypeFloatVector:
vectors := make([][]float32, validCount)
for i := range validCount {
vec := make([]float32, common.DefaultDim)
for j := range common.DefaultDim {
vec[j] = float32(i*common.DefaultDim+j) / 10000.0
}
vectors[i] = vec
}
originalVectors = vectors
vecColumn, err = column.NewNullableColumnFloatVector("vector", common.DefaultDim, vectors, validData)
if validCount > 0 {
searchVec = entity.FloatVector(vectors[0])
}
case entity.FieldTypeBinaryVector:
vectors := make([][]byte, validCount)
byteDim := common.DefaultDim / 8
for i := range validCount {
vec := make([]byte, byteDim)
for j := range byteDim {
vec[j] = byte((i + j) % 256)
}
vectors[i] = vec
}
originalVectors = vectors
vecColumn, err = column.NewNullableColumnBinaryVector("vector", common.DefaultDim, vectors, validData)
if validCount > 0 {
searchVec = entity.BinaryVector(vectors[0])
}
case entity.FieldTypeFloat16Vector:
vectors := make([][]byte, validCount)
for i := range validCount {
vectors[i] = common.GenFloat16Vector(common.DefaultDim)
}
originalVectors = vectors
vecColumn, err = column.NewNullableColumnFloat16Vector("vector", common.DefaultDim, vectors, validData)
if validCount > 0 {
searchVec = entity.Float16Vector(vectors[0])
}
case entity.FieldTypeBFloat16Vector:
vectors := make([][]byte, validCount)
for i := range validCount {
vectors[i] = common.GenBFloat16Vector(common.DefaultDim)
}
originalVectors = vectors
vecColumn, err = column.NewNullableColumnBFloat16Vector("vector", common.DefaultDim, vectors, validData)
if validCount > 0 {
searchVec = entity.BFloat16Vector(vectors[0])
}
case entity.FieldTypeInt8Vector:
vectors := make([][]int8, validCount)
for i := range validCount {
vec := make([]int8, common.DefaultDim)
for j := range common.DefaultDim {
vec[j] = int8((i + j) % 127)
}
vectors[i] = vec
}
originalVectors = vectors
vecColumn, err = column.NewNullableColumnInt8Vector("vector", common.DefaultDim, vectors, validData)
if validCount > 0 {
searchVec = entity.Int8Vector(vectors[0])
}
case entity.FieldTypeSparseVector:
vectors := make([]entity.SparseEmbedding, validCount)
for i := range validCount {
positions := []uint32{0, uint32(i + 1), uint32(i + 1000)}
values := []float32{1.0, float32(i+1) / 1000.0, 0.1}
vectors[i], err = entity.NewSliceSparseEmbedding(positions, values)
common.CheckErr(t, err, true)
}
originalVectors = vectors
vecColumn, err = column.NewNullableColumnSparseFloatVector("vector", vectors, validData)
if validCount > 0 {
searchVec = vectors[0]
}
}
common.CheckErr(t, err, true)
insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn, vecColumn))
common.CheckErr(t, err, true)
require.EqualValues(t, nb, insertRes.InsertCount)
flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName))
common.CheckErr(t, err, true)
err = flushTask.Await(ctx)
common.CheckErr(t, err, true)
if validCount > 0 {
var vecIndex index.Index
switch vt.FieldType {
case entity.FieldTypeBinaryVector:
vecIndex = index.NewGenericIndex("vector", map[string]string{
index.MetricTypeKey: string(entity.JACCARD),
index.IndexTypeKey: "BIN_FLAT",
})
case entity.FieldTypeInt8Vector:
vecIndex = index.NewGenericIndex("vector", map[string]string{
index.MetricTypeKey: string(entity.COSINE),
index.IndexTypeKey: "HNSW",
"M": "16",
"efConstruction": "200",
})
case entity.FieldTypeSparseVector:
vecIndex = index.NewSparseInvertedIndex(entity.IP, 0.1)
default:
vecIndex = index.NewFlatIndex(entity.L2)
}
indexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vector", vecIndex))
common.CheckErr(t, err, true)
err = indexTask.Await(ctx)
common.CheckErr(t, err, true)
loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName))
common.CheckErr(t, err, true)
err = loadTask.Await(ctx)
common.CheckErr(t, err, true)
searchRes, err := mc.Search(ctx, client.NewSearchOption(collName, 10, []entity.Vector{searchVec}).WithANNSField("vector"))
common.CheckErr(t, err, true)
require.EqualValues(t, 1, len(searchRes))
searchIDs := searchRes[0].IDs.(*column.ColumnInt64).Data()
require.True(t, len(searchIDs) > 0, "search should return results")
expectedTopK := 10
if validCount < expectedTopK {
expectedTopK = validCount
}
require.EqualValues(t, expectedTopK, len(searchIDs), "search should return expected number of results")
for _, id := range searchIDs {
require.True(t, id >= 0 && id < int64(nb), "search result ID %d should be in range [0, %d)", id, nb)
}
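// verifyVectorData checks every returned row: pks present in pkToVecIdx must round-trip their original vector, all other pks must report IsNull with nil data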
verifyVectorData := func(queryResult client.ResultSet, context string) {
pkCol := queryResult.GetColumn(common.DefaultInt64FieldName).(*column.ColumnInt64)
vecCol := queryResult.GetColumn("vector")
for i := 0; i < queryResult.ResultCount; i++ {
pk, _ := pkCol.GetAsInt64(i)
isNull, _ := vecCol.IsNull(i)
if origIdx, ok := pkToVecIdx[pk]; ok {
require.False(t, isNull, "%s: vector should not be null for pk %d", context, pk)
vecData, _ := vecCol.Get(i)
switch vt.FieldType {
case entity.FieldTypeFloatVector:
vectors := originalVectors.([][]float32)
queriedVec := []float32(vecData.(entity.FloatVector))
require.EqualValues(t, common.DefaultDim, len(queriedVec), "%s: vector dimension should match for pk %d", context, pk)
origVec := vectors[origIdx]
for j := range origVec {
require.InDelta(t, origVec[j], queriedVec[j], 1e-6, "%s: vector element %d should match for pk %d", context, j, pk)
}
case entity.FieldTypeInt8Vector:
vectors := originalVectors.([][]int8)
queriedVec := []int8(vecData.(entity.Int8Vector))
require.EqualValues(t, common.DefaultDim, len(queriedVec), "%s: vector dimension should match for pk %d", context, pk)
origVec := vectors[origIdx]
for j := range origVec {
require.EqualValues(t, origVec[j], queriedVec[j], "%s: vector element %d should match for pk %d", context, j, pk)
}
case entity.FieldTypeBinaryVector:
vectors := originalVectors.([][]byte)
queriedVec := []byte(vecData.(entity.BinaryVector))
byteDim := common.DefaultDim / 8
require.EqualValues(t, byteDim, len(queriedVec), "%s: vector byte dimension should match for pk %d", context, pk)
origVec := vectors[origIdx]
for j := range origVec {
require.EqualValues(t, origVec[j], queriedVec[j], "%s: vector byte %d should match for pk %d", context, j, pk)
}
case entity.FieldTypeFloat16Vector:
queriedVec := []byte(vecData.(entity.Float16Vector))
byteDim := common.DefaultDim * 2
require.EqualValues(t, byteDim, len(queriedVec), "%s: vector byte dimension should match for pk %d", context, pk)
case entity.FieldTypeBFloat16Vector:
queriedVec := []byte(vecData.(entity.BFloat16Vector))
byteDim := common.DefaultDim * 2
require.EqualValues(t, byteDim, len(queriedVec), "%s: vector byte dimension should match for pk %d", context, pk)
case entity.FieldTypeSparseVector:
vectors := originalVectors.([]entity.SparseEmbedding)
queriedVec := vecData.(entity.SparseEmbedding)
origVec := vectors[origIdx]
require.EqualValues(t, origVec.Len(), queriedVec.Len(), "%s: sparse vector length should match for pk %d", context, pk)
for j := 0; j < origVec.Len(); j++ {
origPos, origVal, _ := origVec.Get(j)
queriedPos, queriedVal, _ := queriedVec.Get(j)
require.EqualValues(t, origPos, queriedPos, "%s: sparse vector position %d should match for pk %d", context, j, pk)
require.InDelta(t, origVal, queriedVal, 1e-6, "%s: sparse vector value %d should match for pk %d", context, j, pk)
}
}
} else {
require.True(t, isNull, "%s: vector should be null for pk %d", context, pk)
vecData, _ := vecCol.Get(i)
require.Nil(t, vecData, "%s: null vector data should be nil for pk %d", context, pk)
}
}
}
if len(searchIDs) > 0 {
searchQueryRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter(fmt.Sprintf("int64 in [%s]", int64SliceToString(searchIDs))).WithOutputFields("vector"))
common.CheckErr(t, err, true)
require.EqualValues(t, len(searchIDs), searchQueryRes.ResultCount, "query by search IDs should return all IDs")
verifyVectorData(searchQueryRes, "Search results")
}
queryRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 < 10").WithOutputFields("vector"))
common.CheckErr(t, err, true)
expectedQueryCount := 10
require.EqualValues(t, expectedQueryCount, queryRes.ResultCount, "query should return expected count")
verifyVectorData(queryRes, "Query int64 < 10")
searchRes, err = mc.Search(ctx, client.NewSearchOption(collName, 10, []entity.Vector{searchVec}).
WithANNSField("vector").WithFilter("int64 < 100").WithOutputFields("vector"))
common.CheckErr(t, err, true)
require.EqualValues(t, 1, len(searchRes))
filteredIDs := searchRes[0].IDs.(*column.ColumnInt64).Data()
for _, id := range filteredIDs {
require.True(t, id < 100, "filtered search should only return IDs < 100, got %d", id)
}
filteredSearchQueryRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter(fmt.Sprintf("int64 in [%s]", int64SliceToString(filteredIDs))).WithOutputFields("vector"))
common.CheckErr(t, err, true)
verifyVectorData(filteredSearchQueryRes, "Filtered search results")
countRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("").WithOutputFields("count(*)"))
common.CheckErr(t, err, true)
totalCount, err := countRes.Fields[0].GetAsInt64(0)
common.CheckErr(t, err, true)
require.EqualValues(t, nb, totalCount, "total count should equal inserted rows")
}
err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName))
common.CheckErr(t, err, true)
})
}
}
}

func TestNullableVectorWithScalarFilter(t *testing.T) {
vectorTypes := GetVectorTypes()
nullPercents := GetNullPercents()
for _, vt := range vectorTypes {
for _, nullPercent := range nullPercents {
testName := fmt.Sprintf("%s_%d%%null", vt.Name, nullPercent)
t.Run(testName, func(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
collName := common.GenRandomString("nullable_vec_filter", 5)
pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
vecField := entity.NewField().WithName("vector").WithDataType(vt.FieldType).WithNullable(true)
if vt.FieldType != entity.FieldTypeSparseVector {
vecField = vecField.WithDim(common.DefaultDim)
}
tagField := entity.NewField().WithName("tag").WithDataType(entity.FieldTypeVarChar).WithMaxLength(100).WithNullable(true)
schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField).WithField(tagField)
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
nb := 500
testData := GenerateNullableVectorTestData(t, vt, nb, nullPercent, "vector")
// tag field: 50% null (even rows are valid)
tagValidData := make([]bool, nb)
tagValidCount := 0
for i := range nb {
tagValidData[i] = i%2 == 0
if tagValidData[i] {
tagValidCount++
}
}
pkData := make([]int64, nb)
for i := range nb {
pkData[i] = int64(i)
}
pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData)
tagData := make([]string, tagValidCount)
for i := range tagValidCount {
tagData[i] = fmt.Sprintf("tag_%d", i)
}
tagColumn, err := column.NewNullableColumnVarChar("tag", tagData, tagValidData)
common.CheckErr(t, err, true)
insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn, testData.VecColumn, tagColumn))
common.CheckErr(t, err, true)
require.EqualValues(t, nb, insertRes.InsertCount)
flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName))
common.CheckErr(t, err, true)
err = flushTask.Await(ctx)
common.CheckErr(t, err, true)
if testData.ValidCount > 0 {
vecIndex := CreateNullableVectorIndex(vt)
indexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vector", vecIndex))
common.CheckErr(t, err, true)
err = indexTask.Await(ctx)
common.CheckErr(t, err, true)
loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName))
common.CheckErr(t, err, true)
err = loadTask.Await(ctx)
common.CheckErr(t, err, true)
// Query with scalar filter: tag is not null and int64 < 50
// int64 < 50 => 50 rows (pk 0-49)
// tag is not null => even rows only (pk 0, 2, 4, ..., 48) => 25 rows
queryRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("tag is not null and int64 < 50").WithOutputFields("vector", "tag"))
common.CheckErr(t, err, true)
require.EqualValues(t, 25, queryRes.ResultCount, "query should return 25 rows with tag not null and int64 < 50")
VerifyNullableVectorData(t, vt, queryRes, testData.PkToVecIdx, testData.OriginalVectors, "Query with tag filter")
}
// clean up
err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName))
common.CheckErr(t, err, true)
})
}
}
}

func TestNullableVectorDelete(t *testing.T) {
vectorTypes := GetVectorTypes()
nullPercents := GetNullPercents()
for _, vt := range vectorTypes {
for _, nullPercent := range nullPercents {
testName := fmt.Sprintf("%s_%d%%null", vt.Name, nullPercent)
t.Run(testName, func(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
collName := common.GenRandomString("nullable_vec_del", 5)
pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
vecField := entity.NewField().WithName("vector").WithDataType(vt.FieldType).WithNullable(true)
if vt.FieldType != entity.FieldTypeSparseVector {
vecField = vecField.WithDim(common.DefaultDim)
}
schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField)
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
nb := 100
testData := GenerateNullableVectorTestData(t, vt, nb, nullPercent, "vector")
pkData := make([]int64, nb)
for i := range nb {
pkData[i] = int64(i)
}
pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData)
insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn, testData.VecColumn))
common.CheckErr(t, err, true)
require.EqualValues(t, nb, insertRes.InsertCount)
flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName))
common.CheckErr(t, err, true)
err = flushTask.Await(ctx)
common.CheckErr(t, err, true)
if testData.ValidCount > 0 {
vecIndex := CreateNullableVectorIndex(vt)
indexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vector", vecIndex))
common.CheckErr(t, err, true)
err = indexTask.Await(ctx)
common.CheckErr(t, err, true)
loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName))
common.CheckErr(t, err, true)
err = loadTask.Await(ctx)
common.CheckErr(t, err, true)
// Delete first 25 rows and last 25 rows
delRes, err := mc.Delete(ctx, client.NewDeleteOption(collName).WithExpr("int64 < 25"))
common.CheckErr(t, err, true)
require.EqualValues(t, 25, delRes.DeleteCount)
delRes, err = mc.Delete(ctx, client.NewDeleteOption(collName).WithExpr("int64 >= 75"))
common.CheckErr(t, err, true)
require.EqualValues(t, 25, delRes.DeleteCount)
// Verify remaining count
queryRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("").WithOutputFields("count(*)"))
common.CheckErr(t, err, true)
count, err := queryRes.Fields[0].GetAsInt64(0)
common.CheckErr(t, err, true)
require.EqualValues(t, 50, count, "remaining count should be 100 - 25 - 25 = 50")
// Verify deleted rows don't exist
queryDeletedRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 < 25").WithOutputFields("count(*)"))
common.CheckErr(t, err, true)
deletedCount, err := queryDeletedRes.Fields[0].GetAsInt64(0)
common.CheckErr(t, err, true)
require.EqualValues(t, 0, deletedCount, "deleted rows should not exist")
queryDeletedValidRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 >= 75").WithOutputFields("count(*)"))
common.CheckErr(t, err, true)
deletedValidCount, err := queryDeletedValidRes.Fields[0].GetAsInt64(0)
common.CheckErr(t, err, true)
require.EqualValues(t, 0, deletedValidCount, "deleted valid vector rows should not exist")
// Verify remaining rows with vector data
queryValidRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 >= 25 and int64 < 75").WithOutputFields("vector"))
common.CheckErr(t, err, true)
require.EqualValues(t, 50, queryValidRes.ResultCount, "should have 50 remaining rows")
VerifyNullableVectorData(t, vt, queryValidRes, testData.PkToVecIdx, testData.OriginalVectors, "Remaining vector rows")
}
// clean up
err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName))
common.CheckErr(t, err, true)
})
}
}
}

func TestNullableVectorUpsert(t *testing.T) {
autoIDOptions := []bool{false, true}
for _, autoID := range autoIDOptions {
testName := fmt.Sprintf("AutoID=%v", autoID)
t.Run(testName, func(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
collName := common.GenRandomString("nullable_vec_ups", 5)
pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
if autoID {
pkField.AutoID = true
}
vecField := entity.NewField().WithName(common.DefaultFloatVecFieldName).WithDataType(entity.FieldTypeFloatVector).WithDim(common.DefaultDim).WithNullable(true)
schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField)
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
// insert initial data with 50% null
nb := 100
nullPercent := 50
validData := make([]bool, nb)
validCount := 0
for i := range nb {
validData[i] = (i % 100) >= nullPercent
if validData[i] {
validCount++
}
}
pkData := make([]int64, nb)
for i := range nb {
pkData[i] = int64(i)
}
pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData)
vectors := make([][]float32, validCount)
for i := range validCount {
vec := make([]float32, common.DefaultDim)
for j := range common.DefaultDim {
vec[j] = float32(i*common.DefaultDim+j) / 10000.0
}
vectors[i] = vec
}
vecColumn, err := column.NewNullableColumnFloatVector(common.DefaultFloatVecFieldName, common.DefaultDim, vectors, validData)
common.CheckErr(t, err, true)
var insertRes client.InsertResult
if autoID {
insertRes, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(collName).WithColumns(vecColumn))
} else {
insertRes, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(collName).WithColumns(pkColumn, vecColumn))
}
common.CheckErr(t, err, true)
require.EqualValues(t, nb, insertRes.InsertCount)
var actualPkData []int64
if autoID {
insertedIDs := insertRes.IDs.(*column.ColumnInt64)
actualPkData = insertedIDs.Data()
require.EqualValues(t, nb, len(actualPkData), "inserted PK count should match")
} else {
actualPkData = pkData
}
flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName))
common.CheckErr(t, err, true)
err = flushTask.Await(ctx)
common.CheckErr(t, err, true)
indexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, common.DefaultFloatVecFieldName, index.NewFlatIndex(entity.L2)))
common.CheckErr(t, err, true)
err = indexTask.Await(ctx)
common.CheckErr(t, err, true)
loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName))
common.CheckErr(t, err, true)
err = loadTask.Await(ctx)
common.CheckErr(t, err, true)
// upsert: change first 25 rows (originally null) to valid, change rows 50-74 (originally valid) to null
upsertNb := 50
upsertValidData := make([]bool, upsertNb)
for i := range upsertNb {
upsertValidData[i] = i < 25
}
upsertPkData := make([]int64, upsertNb)
for i := range upsertNb {
if i < 25 {
upsertPkData[i] = actualPkData[i]
} else {
upsertPkData[i] = actualPkData[i+25]
}
}
upsertPkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, upsertPkData)
upsertVectors := make([][]float32, 25)
for i := range 25 {
vec := make([]float32, common.DefaultDim)
for j := range common.DefaultDim {
vec[j] = float32((i+100)*common.DefaultDim+j) / 10000.0
}
upsertVectors[i] = vec
}
upsertVecColumn, err := column.NewNullableColumnFloatVector(common.DefaultFloatVecFieldName, common.DefaultDim, upsertVectors, upsertValidData)
common.CheckErr(t, err, true)
upsertRes, err := mc.Upsert(ctx, client.NewColumnBasedInsertOption(collName, upsertPkColumn, upsertVecColumn))
common.CheckErr(t, err, true)
require.EqualValues(t, upsertNb, upsertRes.UpsertCount)
var upsertedPks []int64
if autoID {
upsertedIDs := upsertRes.IDs.(*column.ColumnInt64)
upsertedPks = upsertedIDs.Data()
require.EqualValues(t, upsertNb, len(upsertedPks), "upserted PK count should match")
} else {
upsertedPks = upsertPkData
}
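// expected state after upsert:
//   upsertedPks[0:25]    -> the newly written upsertVectors (previously null)
//   upsertedPks[25:50]   -> null (upserted with valid=false, previously valid)
//   actualPkData[25:50]  -> still null (untouched rows that were inserted as null)
//   actualPkData[75:100] -> original vectors (untouched valid rows)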
expectedVectorMap := make(map[int64][]float32)
for i := 0; i < 25; i++ {
expectedVectorMap[upsertedPks[i]] = upsertVectors[i]
}
for i := 25; i < 50; i++ {
expectedVectorMap[upsertedPks[i]] = nil
}
for i := 25; i < 50; i++ {
expectedVectorMap[actualPkData[i]] = nil
}
for i := 75; i < 100; i++ {
vecIdx := i - 50
expectedVectorMap[actualPkData[i]] = vectors[vecIdx]
}
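// presumably waiting out the flush rate limiter before the second flush (rate=0.1 means 1 flush per 10s; see the same wait in TestNullableVectorCompaction)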
time.Sleep(10 * time.Second)
flushTask, err = mc.Flush(ctx, client.NewFlushOption(collName))
common.CheckErr(t, err, true)
err = flushTask.Await(ctx)
common.CheckErr(t, err, true)
err = mc.ReleaseCollection(ctx, client.NewReleaseCollectionOption(collName))
common.CheckErr(t, err, true)
loadTask, err = mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName))
common.CheckErr(t, err, true)
err = loadTask.Await(ctx)
common.CheckErr(t, err, true)
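// 25 rows upserted to valid plus the 25 untouched valid rows (pk 75-99) leave 50 valid vectors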
expectedValidCount := 50
searchVec := entity.FloatVector(common.GenFloatVector(common.DefaultDim))
searchRes, err := mc.Search(ctx, client.NewSearchOption(collName, 50, []entity.Vector{searchVec}).WithANNSField(common.DefaultFloatVecFieldName))
common.CheckErr(t, err, true)
require.EqualValues(t, 1, len(searchRes))
searchIDs := searchRes[0].IDs.(*column.ColumnInt64).Data()
require.EqualValues(t, expectedValidCount, len(searchIDs), "search should return all 50 valid vectors")
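// verifyVectorData checks each returned row against expectedVectorMap: nil entries must be null, non-nil entries must round-trip their expected vector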
verifyVectorData := func(queryResult client.ResultSet, context string) {
pkCol := queryResult.GetColumn(common.DefaultInt64FieldName).(*column.ColumnInt64)
vecCol := queryResult.GetColumn(common.DefaultFloatVecFieldName).(*column.ColumnFloatVector)
for i := 0; i < queryResult.ResultCount; i++ {
pk, _ := pkCol.GetAsInt64(i)
isNull, _ := vecCol.IsNull(i)
expectedVec, exists := expectedVectorMap[pk]
require.True(t, exists, "%s: unexpected PK %d in query results", context, pk)
if expectedVec != nil {
require.False(t, isNull, "%s: vector should not be null for pk %d", context, pk)
vecData, _ := vecCol.Get(i)
queriedVec := []float32(vecData.(entity.FloatVector))
require.EqualValues(t, common.DefaultDim, len(queriedVec), "%s: vector dimension should match for pk %d", context, pk)
for j := range expectedVec {
require.InDelta(t, expectedVec[j], queriedVec[j], 1e-6, "%s: vector element %d should match for pk %d", context, j, pk)
}
} else {
require.True(t, isNull, "%s: vector should be null for pk %d", context, pk)
vecData, _ := vecCol.Get(i)
require.Nil(t, vecData, "%s: null vector data should be nil for pk %d", context, pk)
}
}
}
searchQueryRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter(fmt.Sprintf("int64 in [%s]", int64SliceToString(searchIDs))).WithOutputFields(common.DefaultFloatVecFieldName))
common.CheckErr(t, err, true)
verifyVectorData(searchQueryRes, "All valid vectors after upsert")
upsertedToValidPKs := upsertedPks[0:25]
queryUpsertedRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter(fmt.Sprintf("int64 in [%s]", int64SliceToString(upsertedToValidPKs))).WithOutputFields(common.DefaultFloatVecFieldName))
common.CheckErr(t, err, true)
require.EqualValues(t, 25, queryUpsertedRes.ResultCount, "should have 25 rows for upserted to valid")
verifyVectorData(queryUpsertedRes, "Upserted valid rows")
countRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("").WithOutputFields("count(*)"))
common.CheckErr(t, err, true)
totalCount, err := countRes.Fields[0].GetAsInt64(0)
common.CheckErr(t, err, true)
require.EqualValues(t, nb, totalCount, "total count after upsert should still be %d", nb)
// clean up
err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName))
common.CheckErr(t, err, true)
})
}
}

func TestNullableVectorAllNull(t *testing.T) {
vectorTypes := GetVectorTypes()
for _, vt := range vectorTypes {
testName := fmt.Sprintf("%s_100%%null", vt.Name)
t.Run(testName, func(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
collName := common.GenRandomString("nullable_vec_all", 5)
pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
vecField := entity.NewField().WithName("vector").WithDataType(vt.FieldType).WithNullable(true)
if vt.FieldType != entity.FieldTypeSparseVector {
vecField = vecField.WithDim(common.DefaultDim)
}
schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField)
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
// Generate test data with 100% null (nullPercent = 100)
nb := 100
testData := GenerateNullableVectorTestData(t, vt, nb, 100, "vector")
pkData := make([]int64, nb)
for i := range nb {
pkData[i] = int64(i)
}
pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData)
// insert
insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn, testData.VecColumn))
common.CheckErr(t, err, true)
require.EqualValues(t, nb, insertRes.InsertCount)
// flush
flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName))
common.CheckErr(t, err, true)
err = flushTask.Await(ctx)
common.CheckErr(t, err, true)
// create index and load
vecIndex := CreateNullableVectorIndex(vt)
indexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vector", vecIndex))
common.CheckErr(t, err, true)
err = indexTask.Await(ctx)
common.CheckErr(t, err, true)
loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName))
common.CheckErr(t, err, true)
err = loadTask.Await(ctx)
common.CheckErr(t, err, true)
// Generate a search vector (won't match anything since all are null)
var searchVec entity.Vector
switch vt.FieldType {
case entity.FieldTypeFloatVector:
searchVec = entity.FloatVector(common.GenFloatVector(common.DefaultDim))
case entity.FieldTypeBinaryVector:
searchVec = entity.BinaryVector(make([]byte, common.DefaultDim/8))
case entity.FieldTypeFloat16Vector:
searchVec = entity.Float16Vector(common.GenFloat16Vector(common.DefaultDim))
case entity.FieldTypeBFloat16Vector:
searchVec = entity.BFloat16Vector(common.GenBFloat16Vector(common.DefaultDim))
case entity.FieldTypeInt8Vector:
vec := make([]int8, common.DefaultDim)
searchVec = entity.Int8Vector(vec)
case entity.FieldTypeSparseVector:
searchVec, _ = entity.NewSliceSparseEmbedding([]uint32{0}, []float32{1.0})
}
// search should return empty results since all vectors are null (not searchable)
searchRes, err := mc.Search(ctx, client.NewSearchOption(collName, 10, []entity.Vector{searchVec}).WithANNSField("vector"))
common.CheckErr(t, err, true)
require.EqualValues(t, 1, len(searchRes))
searchIDs := searchRes[0].IDs.(*column.ColumnInt64).Data()
require.EqualValues(t, 0, len(searchIDs), "search should return empty results for all-null vectors")
// query should return all rows
queryRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("").WithOutputFields("count(*)"))
common.CheckErr(t, err, true)
count, err := queryRes.Fields[0].GetAsInt64(0)
common.CheckErr(t, err, true)
require.EqualValues(t, nb, count, "query should return all %d rows even with 100%% null vectors", nb)
// clean up
err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName))
common.CheckErr(t, err, true)
})
}
}

func TestNullableVectorMultiFields(t *testing.T) {
vectorTypes := GetVectorTypes()
for _, vt := range vectorTypes {
testName := vt.Name
t.Run(testName, func(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
collName := common.GenRandomString("nullable_vec_multi", 5)
pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
vecField1 := entity.NewField().WithName("vec1").WithDataType(vt.FieldType).WithNullable(true)
vecField2 := entity.NewField().WithName("vec2").WithDataType(vt.FieldType).WithNullable(true)
if vt.FieldType != entity.FieldTypeSparseVector {
vecField1 = vecField1.WithDim(common.DefaultDim)
vecField2 = vecField2.WithDim(common.DefaultDim)
}
schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField1).WithField(vecField2)
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
// generate data: vec1 has 70% valid (first 30 per 100 are invalid), vec2 has 30% valid (first 70 per 100 are invalid)
nb := 100
nullPercent1 := 30 // vec1: 30% null
nullPercent2 := 70 // vec2: 70% null
// Generate test data for both vector fields
testData1 := GenerateNullableVectorTestData(t, vt, nb, nullPercent1, "vec1")
testData2 := GenerateNullableVectorTestData(t, vt, nb, nullPercent2, "vec2")
pkData := make([]int64, nb)
for i := range nb {
pkData[i] = int64(i)
}
pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData)
// insert
insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn, testData1.VecColumn, testData2.VecColumn))
common.CheckErr(t, err, true)
require.EqualValues(t, nb, insertRes.InsertCount)
// flush
flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName))
common.CheckErr(t, err, true)
err = flushTask.Await(ctx)
common.CheckErr(t, err, true)
// create indexes for both vector fields
vecIndex := CreateNullableVectorIndex(vt)
indexTask1, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vec1", vecIndex))
common.CheckErr(t, err, true)
err = indexTask1.Await(ctx)
common.CheckErr(t, err, true)
indexTask2, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vec2", vecIndex))
common.CheckErr(t, err, true)
err = indexTask2.Await(ctx)
common.CheckErr(t, err, true)
// load
loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName))
common.CheckErr(t, err, true)
err = loadTask.Await(ctx)
common.CheckErr(t, err, true)
// search on vec1
searchRes1, err := mc.Search(ctx, client.NewSearchOption(collName, 10, []entity.Vector{testData1.SearchVec}).WithANNSField("vec1").WithOutputFields("vec1"))
common.CheckErr(t, err, true)
require.EqualValues(t, 1, len(searchRes1))
searchIDs1 := searchRes1[0].IDs.(*column.ColumnInt64).Data()
require.EqualValues(t, 10, len(searchIDs1), "search on vec1 should return 10 results")
for _, id := range searchIDs1 {
_, ok := testData1.PkToVecIdx[id]
require.True(t, ok, "search on vec1 should only return rows where vec1 is valid, got pk %d", id)
}
// search on vec2
searchRes2, err := mc.Search(ctx, client.NewSearchOption(collName, 10, []entity.Vector{testData2.SearchVec}).WithANNSField("vec2").WithOutputFields("vec2"))
common.CheckErr(t, err, true)
require.EqualValues(t, 1, len(searchRes2))
searchIDs2 := searchRes2[0].IDs.(*column.ColumnInt64).Data()
require.EqualValues(t, 10, len(searchIDs2), "search on vec2 should return 10 results")
for _, id := range searchIDs2 {
_, ok := testData2.PkToVecIdx[id]
require.True(t, ok, "search on vec2 should only return rows where vec2 is valid, got pk %d", id)
}
// query and verify - rows 0-29 both null
queryRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 < 30").WithOutputFields("vec1", "vec2"))
common.CheckErr(t, err, true)
VerifyNullableVectorDataWithFieldName(t, vt, queryRes, testData1.PkToVecIdx, testData1.OriginalVectors, "vec1", "query0-29 vec1")
VerifyNullableVectorDataWithFieldName(t, vt, queryRes, testData2.PkToVecIdx, testData2.OriginalVectors, "vec2", "query0-29 vec2")
// query rows 30-69: vec1 valid, vec2 null
queryMixedRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 >= 30 AND int64 < 70").WithOutputFields("vec1", "vec2"))
common.CheckErr(t, err, true)
VerifyNullableVectorDataWithFieldName(t, vt, queryMixedRes, testData1.PkToVecIdx, testData1.OriginalVectors, "vec1", "query30-69 vec1")
VerifyNullableVectorDataWithFieldName(t, vt, queryMixedRes, testData2.PkToVecIdx, testData2.OriginalVectors, "vec2", "query30-69 vec2")
// query rows 70-99: both valid
queryBothValidRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 >= 70").WithOutputFields("vec1", "vec2"))
common.CheckErr(t, err, true)
VerifyNullableVectorDataWithFieldName(t, vt, queryBothValidRes, testData1.PkToVecIdx, testData1.OriginalVectors, "vec1", "query70-99 vec1")
VerifyNullableVectorDataWithFieldName(t, vt, queryBothValidRes, testData2.PkToVecIdx, testData2.OriginalVectors, "vec2", "query70-99 vec2")
// clean up
err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName))
common.CheckErr(t, err, true)
})
}
}

func TestNullableVectorPaginatedQuery(t *testing.T) {
vectorTypes := GetVectorTypes()
nullPercents := GetNullPercents()
for _, vt := range vectorTypes {
for _, nullPercent := range nullPercents {
testName := fmt.Sprintf("%s_%d%%null", vt.Name, nullPercent)
t.Run(testName, func(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
collName := common.GenRandomString("nullable_vec_page", 5)
pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
vecField := entity.NewField().WithName("vector").WithDataType(vt.FieldType).WithNullable(true)
if vt.FieldType != entity.FieldTypeSparseVector {
vecField = vecField.WithDim(common.DefaultDim)
}
schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField)
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
nb := 200
testData := GenerateNullableVectorTestData(t, vt, nb, nullPercent, "vector")
pkData := make([]int64, nb)
for i := range nb {
pkData[i] = int64(i)
}
pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData)
// insert
insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn, testData.VecColumn))
common.CheckErr(t, err, true)
require.EqualValues(t, nb, insertRes.InsertCount)
flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName))
common.CheckErr(t, err, true)
err = flushTask.Await(ctx)
common.CheckErr(t, err, true)
if testData.ValidCount > 0 {
// create index and load
vecIndex := CreateNullableVectorIndex(vt)
indexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vector", vecIndex))
common.CheckErr(t, err, true)
err = indexTask.Await(ctx)
common.CheckErr(t, err, true)
loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName))
common.CheckErr(t, err, true)
err = loadTask.Await(ctx)
common.CheckErr(t, err, true)
// Test pagination: page1 offset=0, limit=50
page1Res, err := mc.Query(ctx, client.NewQueryOption(collName).
WithFilter("").WithOutputFields("vector").WithOffset(0).WithLimit(50))
common.CheckErr(t, err, true)
require.EqualValues(t, 50, page1Res.ResultCount, "page 1 should return 50 rows")
VerifyNullableVectorData(t, vt, page1Res, testData.PkToVecIdx, testData.OriginalVectors, "page1")
// Test pagination: page2 offset=50, limit=50
page2Res, err := mc.Query(ctx, client.NewQueryOption(collName).
WithFilter("").WithOutputFields("vector").WithOffset(50).WithLimit(50))
common.CheckErr(t, err, true)
require.EqualValues(t, 50, page2Res.ResultCount, "page 2 should return 50 rows")
VerifyNullableVectorData(t, vt, page2Res, testData.PkToVecIdx, testData.OriginalVectors, "page2")
// Test pagination: page3 offset=100, limit=50
page3Res, err := mc.Query(ctx, client.NewQueryOption(collName).
WithFilter("").WithOutputFields("vector").WithOffset(100).WithLimit(50))
common.CheckErr(t, err, true)
require.EqualValues(t, 50, page3Res.ResultCount, "page 3 should return 50 rows")
VerifyNullableVectorData(t, vt, page3Res, testData.PkToVecIdx, testData.OriginalVectors, "page3")
// Test mixed query with filter
mixedPageRes, err := mc.Query(ctx, client.NewQueryOption(collName).
WithFilter("int64 >= 40 and int64 < 60").
WithOutputFields("vector"))
common.CheckErr(t, err, true)
require.EqualValues(t, 20, mixedPageRes.ResultCount, "mixed query should return 20 rows")
VerifyNullableVectorData(t, vt, mixedPageRes, testData.PkToVecIdx, testData.OriginalVectors, "mixed query")
}
// clean up
err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName))
common.CheckErr(t, err, true)
})
}
}
}

func TestNullableVectorMultiPartitions(t *testing.T) {
vectorTypes := GetVectorTypes()
for _, vt := range vectorTypes {
testName := vt.Name
t.Run(testName, func(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
collName := common.GenRandomString("nullable_vec_part", 5)
pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
vecField := entity.NewField().WithName("vector").WithDataType(vt.FieldType).WithNullable(true)
if vt.FieldType != entity.FieldTypeSparseVector {
vecField = vecField.WithDim(common.DefaultDim)
}
schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField)
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
// create partitions
partitions := []string{"partition_a", "partition_b", "partition_c"}
for _, p := range partitions {
err = mc.CreatePartition(ctx, client.NewCreatePartitionOption(collName, p))
common.CheckErr(t, err, true)
}
// insert data into each partition with different null ratios
nbPerPartition := 100
nullRatios := []int{0, 30, 50} // 0%, 30%, 50% null for each partition
// Store all test data and mappings for verification
allPkToVecIdx := make(map[int64]int)
var allOriginalVectors interface{}
var firstSearchVec entity.Vector
globalVecIdx := 0
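// partitions are inserted in order, so allPkToVecIdx indexes into allOriginalVectors, the concatenation of each partition's packed valid vectors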
for i, partition := range partitions {
nullRatio := nullRatios[i]
testData := GenerateNullableVectorTestData(t, vt, nbPerPartition, nullRatio, "vector")
// pk column with unique ids per partition
pkData := make([]int64, nbPerPartition)
for j := range nbPerPartition {
pkData[j] = int64(i*nbPerPartition + j)
}
pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData)
for j := range nbPerPartition {
if testData.ValidData[j] {
allPkToVecIdx[pkData[j]] = globalVecIdx
globalVecIdx++
}
}
// Accumulate original vectors for verification
switch vt.FieldType {
case entity.FieldTypeFloatVector:
if allOriginalVectors == nil {
allOriginalVectors = make([][]float32, 0)
}
allOriginalVectors = append(allOriginalVectors.([][]float32), testData.OriginalVectors.([][]float32)...)
case entity.FieldTypeBinaryVector:
if allOriginalVectors == nil {
allOriginalVectors = make([][]byte, 0)
}
allOriginalVectors = append(allOriginalVectors.([][]byte), testData.OriginalVectors.([][]byte)...)
case entity.FieldTypeFloat16Vector, entity.FieldTypeBFloat16Vector:
if allOriginalVectors == nil {
allOriginalVectors = make([][]byte, 0)
}
allOriginalVectors = append(allOriginalVectors.([][]byte), testData.OriginalVectors.([][]byte)...)
case entity.FieldTypeInt8Vector:
if allOriginalVectors == nil {
allOriginalVectors = make([][]int8, 0)
}
allOriginalVectors = append(allOriginalVectors.([][]int8), testData.OriginalVectors.([][]int8)...)
case entity.FieldTypeSparseVector:
if allOriginalVectors == nil {
allOriginalVectors = make([]entity.SparseEmbedding, 0)
}
allOriginalVectors = append(allOriginalVectors.([]entity.SparseEmbedding), testData.OriginalVectors.([]entity.SparseEmbedding)...)
}
// Save first search vector (from partition_a with 0% null)
if i == 0 && testData.SearchVec != nil {
firstSearchVec = testData.SearchVec
}
// insert into partition
insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn, testData.VecColumn).WithPartition(partition))
common.CheckErr(t, err, true)
require.EqualValues(t, nbPerPartition, insertRes.InsertCount)
}
// flush
flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName))
common.CheckErr(t, err, true)
err = flushTask.Await(ctx)
common.CheckErr(t, err, true)
// create index and load
vecIndex := CreateNullableVectorIndex(vt)
indexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vector", vecIndex))
common.CheckErr(t, err, true)
err = indexTask.Await(ctx)
common.CheckErr(t, err, true)
loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName))
common.CheckErr(t, err, true)
err = loadTask.Await(ctx)
common.CheckErr(t, err, true)
// search in specific partition - verify all results from partition_a
searchRes, err := mc.Search(ctx, client.NewSearchOption(collName, 10, []entity.Vector{firstSearchVec}).
WithANNSField("vector").
WithPartitions("partition_a"))
common.CheckErr(t, err, true)
require.EqualValues(t, 1, len(searchRes))
searchIDs := searchRes[0].IDs.(*column.ColumnInt64).Data()
require.EqualValues(t, 10, len(searchIDs), "search in partition_a should return 10 results")
// partition_a has 0% null, so all 100 vectors are valid, IDs should be 0-99
for _, id := range searchIDs {
require.True(t, id >= 0 && id < int64(nbPerPartition), "partition_a IDs should be in range [0, %d), got %d", nbPerPartition, id)
// Verify all search results have valid vectors
_, ok := allPkToVecIdx[id]
require.True(t, ok, "search result pk %d should have a valid vector", id)
}
// search across all partitions - should return results from any partition
searchRes, err = mc.Search(ctx, client.NewSearchOption(collName, 50, []entity.Vector{firstSearchVec}).
WithANNSField("vector"))
common.CheckErr(t, err, true)
require.EqualValues(t, 1, len(searchRes))
allSearchIDs := searchRes[0].IDs.(*column.ColumnInt64).Data()
require.EqualValues(t, 50, len(allSearchIDs), "search across all partitions should return 50 results")
// Verify all search results have valid vectors
for _, id := range allSearchIDs {
_, ok := allPkToVecIdx[id]
require.True(t, ok, "all partitions search result pk %d should have a valid vector", id)
}
// query each partition to verify counts
expectedCounts := []int64{100, 100, 100} // total rows in each partition
for i, partition := range partitions {
queryRes, err := mc.Query(ctx, client.NewQueryOption(collName).
WithFilter("").
WithOutputFields("count(*)").
WithPartitions(partition))
common.CheckErr(t, err, true)
count, err := queryRes.Fields[0].GetAsInt64(0)
common.CheckErr(t, err, true)
require.EqualValues(t, expectedCounts[i], count, "partition %s should have %d rows", partition, expectedCounts[i])
}
// query with vector output from specific partition - partition_a (0% null)
queryVecRes, err := mc.Query(ctx, client.NewQueryOption(collName).
WithFilter("int64 < 10").
WithOutputFields("vector").
WithPartitions("partition_a"))
common.CheckErr(t, err, true)
require.EqualValues(t, 10, queryVecRes.ResultCount, "query partition_a with int64 < 10 should return 10 rows")
VerifyNullableVectorData(t, vt, queryVecRes, allPkToVecIdx, allOriginalVectors, "query partition_a int64 < 10")
// query partition_b which has 30% null (rows 100-129 are null, 130-199 are valid)
queryPartBRes, err := mc.Query(ctx, client.NewQueryOption(collName).
WithFilter("int64 >= 100 AND int64 < 150").
WithOutputFields("vector").
WithPartitions("partition_b"))
common.CheckErr(t, err, true)
require.EqualValues(t, 50, queryPartBRes.ResultCount, "query partition_b with 100 <= int64 < 150 should return 50 rows")
VerifyNullableVectorData(t, vt, queryPartBRes, allPkToVecIdx, allOriginalVectors, "query partition_b int64 100-149")
// query partition_c which has 50% null (rows 200-249 are null, 250-299 are valid)
queryPartCRes, err := mc.Query(ctx, client.NewQueryOption(collName).
WithFilter("int64 >= 200 AND int64 < 260").
WithOutputFields("vector").
WithPartitions("partition_c"))
common.CheckErr(t, err, true)
require.EqualValues(t, 60, queryPartCRes.ResultCount, "query partition_c with 200 <= int64 < 260 should return 60 rows")
VerifyNullableVectorData(t, vt, queryPartCRes, allPkToVecIdx, allOriginalVectors, "query partition_c int64 200-259")
// verify total count across all partitions
totalCountRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("").WithOutputFields("count(*)"))
common.CheckErr(t, err, true)
totalCount, err := totalCountRes.Fields[0].GetAsInt64(0)
common.CheckErr(t, err, true)
require.EqualValues(t, nbPerPartition*3, totalCount, "total count should be %d", nbPerPartition*3)
// clean up
err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName))
common.CheckErr(t, err, true)
})
}
}

func TestNullableVectorCompaction(t *testing.T) {
vectorTypes := GetVectorTypes()
for _, vt := range vectorTypes {
testName := vt.Name
t.Run(testName, func(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout*2)
mc := hp.CreateDefaultMilvusClient(ctx, t)
collName := common.GenRandomString("nullable_vec_comp", 5)
pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
vecField := entity.NewField().WithName("vector").WithDataType(vt.FieldType).WithNullable(true)
if vt.FieldType != entity.FieldTypeSparseVector {
vecField = vecField.WithDim(common.DefaultDim)
}
schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField)
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
// insert data in multiple batches to create multiple segments
nb := 200
nullPercent := 30
// Store all vectors and mappings for verification
allPkToVecIdx := make(map[int64]int)
var allOriginalVectors interface{}
var searchVec entity.Vector
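// globalVecIdx tracks the position in allOriginalVectors, the concatenation of the packed valid vectors from both insert batches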
globalVecIdx := 0
// batch 1: generate test data
testData1 := GenerateNullableVectorTestData(t, vt, nb, nullPercent, "vector")
pkData1 := make([]int64, nb)
for i := range nb {
pkData1[i] = int64(i)
}
pkColumn1 := column.NewColumnInt64(common.DefaultInt64FieldName, pkData1)
for i := range nb {
if testData1.ValidData[i] {
allPkToVecIdx[pkData1[i]] = globalVecIdx
globalVecIdx++
}
}
// Store original vectors
allOriginalVectors = testData1.OriginalVectors
searchVec = testData1.SearchVec
insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn1, testData1.VecColumn))
common.CheckErr(t, err, true)
require.EqualValues(t, nb, insertRes.InsertCount)
// flush to create segment
flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName))
common.CheckErr(t, err, true)
err = flushTask.Await(ctx)
common.CheckErr(t, err, true)
// wait for rate limiter reset before next flush (rate=0.1 means 1 flush per 10s)
time.Sleep(10 * time.Second)
testData2 := GenerateNullableVectorTestData(t, vt, nb, nullPercent, "vector")
pkData2 := make([]int64, nb)
for i := range nb {
pkData2[i] = int64(nb + i)
}
pkColumn2 := column.NewColumnInt64(common.DefaultInt64FieldName, pkData2)
for i := range nb {
if testData2.ValidData[i] {
allPkToVecIdx[pkData2[i]] = globalVecIdx
globalVecIdx++
}
}
// Accumulate original vectors for verification
switch vt.FieldType {
case entity.FieldTypeFloatVector:
allOriginalVectors = append(allOriginalVectors.([][]float32), testData2.OriginalVectors.([][]float32)...)
case entity.FieldTypeBinaryVector:
allOriginalVectors = append(allOriginalVectors.([][]byte), testData2.OriginalVectors.([][]byte)...)
case entity.FieldTypeFloat16Vector, entity.FieldTypeBFloat16Vector:
allOriginalVectors = append(allOriginalVectors.([][]byte), testData2.OriginalVectors.([][]byte)...)
case entity.FieldTypeInt8Vector:
allOriginalVectors = append(allOriginalVectors.([][]int8), testData2.OriginalVectors.([][]int8)...)
case entity.FieldTypeSparseVector:
allOriginalVectors = append(allOriginalVectors.([]entity.SparseEmbedding), testData2.OriginalVectors.([]entity.SparseEmbedding)...)
}
insertRes, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn2, testData2.VecColumn))
common.CheckErr(t, err, true)
require.EqualValues(t, nb, insertRes.InsertCount)
// flush to create another segment
flushTask, err = mc.Flush(ctx, client.NewFlushOption(collName))
common.CheckErr(t, err, true)
err = flushTask.Await(ctx)
common.CheckErr(t, err, true)
// create index and load
vecIndex := CreateNullableVectorIndex(vt)
indexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vector", vecIndex))
common.CheckErr(t, err, true)
err = indexTask.Await(ctx)
common.CheckErr(t, err, true)
loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName))
common.CheckErr(t, err, true)
err = loadTask.Await(ctx)
common.CheckErr(t, err, true)
// delete some data (mix of valid and null vectors) - first 50 rows from batch 1
delRes, err := mc.Delete(ctx, client.NewDeleteOption(collName).WithExpr("int64 < 50"))
common.CheckErr(t, err, true)
require.EqualValues(t, 50, delRes.DeleteCount, "should delete 50 rows")
// trigger manual compaction
compactID, err := mc.Compact(ctx, client.NewCompactOption(collName))
common.CheckErr(t, err, true)
t.Logf("Compaction started with ID: %d", compactID)
// wait up to 60s for compaction to complete (best-effort: if it times out, the checks below still validate the data)
for i := 0; i < 60; i++ {
state, err := mc.GetCompactionState(ctx, client.NewGetCompactionStateOption(compactID))
common.CheckErr(t, err, true)
if state == entity.CompactionStateCompleted {
t.Log("Compaction completed")
break
}
time.Sleep(time.Second)
}
// verify remaining count: 400 total - 50 deleted = 350
queryRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("").WithOutputFields("count(*)"))
common.CheckErr(t, err, true)
count, err := queryRes.Fields[0].GetAsInt64(0)
common.CheckErr(t, err, true)
require.EqualValues(t, nb*2-50, count, "remaining count should be 400 - 50 = 350")
// verify deleted rows are gone
queryDeletedRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 < 50").WithOutputFields("count(*)"))
common.CheckErr(t, err, true)
deletedCount, err := queryDeletedRes.Fields[0].GetAsInt64(0)
common.CheckErr(t, err, true)
require.EqualValues(t, 0, deletedCount, "deleted rows should not exist")
// search should still work - verify returns results
searchRes, err := mc.Search(ctx, client.NewSearchOption(collName, 10, []entity.Vector{searchVec}).WithANNSField("vector"))
common.CheckErr(t, err, true)
require.EqualValues(t, 1, len(searchRes))
searchIDs := searchRes[0].IDs.(*column.ColumnInt64).Data()
require.EqualValues(t, 10, len(searchIDs), "search should return 10 results")
// All search results should have IDs >= 50 (since we deleted pk < 50) and have valid vectors
for _, id := range searchIDs {
require.True(t, id >= 50, "search results should not include deleted IDs, got %d", id)
_, ok := allPkToVecIdx[id]
require.True(t, ok, "search result pk %d should have a valid vector", id)
}
// query with output vector field - verify remaining valid vectors in batch 1
queryRes, err = mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 >= 50 and int64 < 100").WithOutputFields("vector"))
common.CheckErr(t, err, true)
require.EqualValues(t, 50, queryRes.ResultCount, "should have 50 rows in range [50, 100)")
VerifyNullableVectorData(t, vt, queryRes, allPkToVecIdx, allOriginalVectors, "query batch1 remaining 50-99")
queryMixedRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 >= 200 and int64 < 250").WithOutputFields("vector"))
common.CheckErr(t, err, true)
require.EqualValues(t, 50, queryMixedRes.ResultCount, "should have 50 rows in range [200, 250)")
VerifyNullableVectorData(t, vt, queryMixedRes, allPkToVecIdx, allOriginalVectors, "query batch2 200-249")
queryBatch2CountRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 >= 200").WithOutputFields("count(*)"))
common.CheckErr(t, err, true)
batch2Count, err := queryBatch2CountRes.Fields[0].GetAsInt64(0)
common.CheckErr(t, err, true)
require.EqualValues(t, nb, batch2Count, "batch 2 should have all %d rows intact", nb)
// clean up
err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName))
common.CheckErr(t, err, true)
})
}
}
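
// The compaction polling loop in the test above is a pattern the other tests
// could reuse. Below is a minimal sketch of such a helper; awaitState is a
// hypothetical name, not part of the client API, and it assumes the caller
// wraps the client call in the poll closure, e.g.:
//
//	awaitState(t, time.Minute, func() (bool, error) {
//		state, err := mc.GetCompactionState(ctx, client.NewGetCompactionStateOption(compactID))
//		return state == entity.CompactionStateCompleted, err
//	})
func awaitState(t *testing.T, timeout time.Duration, poll func() (bool, error)) {
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		// poll once per second until the condition holds or the deadline passes
		done, err := poll()
		common.CheckErr(t, err, true)
		if done {
			return
		}
		time.Sleep(time.Second)
	}
	t.Fatal("awaitState: condition not met within timeout")
}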
func TestNullableVectorAddField(t *testing.T) {
vectorTypes := GetVectorTypes()
for _, vt := range vectorTypes {
t.Run(vt.Name, func(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
collName := common.GenRandomString("nullable_vec_add", 5)
pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
origVecField := entity.NewField().WithName("orig_vec").WithDataType(entity.FieldTypeFloatVector).WithDim(common.DefaultDim)
schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(origVecField)
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
nb := 100
pkData1 := make([]int64, nb)
for i := range nb {
pkData1[i] = int64(i)
}
pkColumn1 := column.NewColumnInt64(common.DefaultInt64FieldName, pkData1)
origVecData1 := make([][]float32, nb)
for i := range nb {
vec := make([]float32, common.DefaultDim)
for j := range common.DefaultDim {
vec[j] = float32(i*common.DefaultDim+j) / 10000.0
}
origVecData1[i] = vec
}
origVecColumn1 := column.NewColumnFloatVector("orig_vec", common.DefaultDim, origVecData1)
insertRes1, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn1, origVecColumn1))
common.CheckErr(t, err, true)
require.EqualValues(t, nb, insertRes1.InsertCount)
flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName))
common.CheckErr(t, err, true)
err = flushTask.Await(ctx)
common.CheckErr(t, err, true)
// wait for rate limiter reset before next flush (rate=0.1 means 1 flush per 10s)
time.Sleep(10 * time.Second)
// SparseVector takes no dim; all other vector types require one
newVecField := entity.NewField().WithName("new_vec").WithDataType(vt.FieldType).WithNullable(true)
if vt.FieldType != entity.FieldTypeSparseVector {
newVecField = newVecField.WithDim(common.DefaultDim)
}
err = mc.AddCollectionField(ctx, client.NewAddCollectionFieldOption(collName, newVecField))
common.CheckErr(t, err, true)
// verify schema updated
coll, err := mc.DescribeCollection(ctx, client.NewDescribeCollectionOption(collName))
common.CheckErr(t, err, true)
require.EqualValues(t, 3, len(coll.Schema.Fields), "should have 3 fields after adding new vector field")
nullPercent := 30 // 30% null
testData := GenerateNullableVectorTestData(t, vt, nb, nullPercent, "new_vec")
pkData2 := make([]int64, nb)
for i := range nb {
pkData2[i] = int64(nb + i) // pk starts from nb
}
pkColumn2 := column.NewColumnInt64(common.DefaultInt64FieldName, pkData2)
origVecData2 := make([][]float32, nb)
for i := range nb {
vec := make([]float32, common.DefaultDim)
for j := range common.DefaultDim {
vec[j] = float32((nb+i)*common.DefaultDim+j) / 10000.0
}
origVecData2[i] = vec
}
origVecColumn2 := column.NewColumnFloatVector("orig_vec", common.DefaultDim, origVecData2)
insertRes2, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn2, origVecColumn2, testData.VecColumn))
common.CheckErr(t, err, true)
require.EqualValues(t, nb, insertRes2.InsertCount)
flushTask2, err := mc.Flush(ctx, client.NewFlushOption(collName))
common.CheckErr(t, err, true)
err = flushTask2.Await(ctx)
common.CheckErr(t, err, true)
// create indexes
origVecIndex := index.NewFlatIndex(entity.L2)
indexTask1, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "orig_vec", origVecIndex))
common.CheckErr(t, err, true)
err = indexTask1.Await(ctx)
common.CheckErr(t, err, true)
newVecIndex := CreateNullableVectorIndexWithFieldName(vt, "new_vec")
indexTask2, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "new_vec", newVecIndex))
common.CheckErr(t, err, true)
err = indexTask2.Await(ctx)
common.CheckErr(t, err, true)
// load collection
loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName))
common.CheckErr(t, err, true)
err = loadTask.Await(ctx)
common.CheckErr(t, err, true)
// verify total count
countRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("").WithOutputFields("count(*)"))
common.CheckErr(t, err, true)
totalCount, err := countRes.Fields[0].GetAsInt64(0)
common.CheckErr(t, err, true)
require.EqualValues(t, nb*2, totalCount, "total count should be %d", nb*2)
searchVec := entity.FloatVector(origVecData1[0])
searchRes1, err := mc.Search(ctx, client.NewSearchOption(collName, 10, []entity.Vector{searchVec}).WithANNSField("orig_vec").WithOutputFields("new_vec"))
common.CheckErr(t, err, true)
require.EqualValues(t, 1, len(searchRes1))
require.EqualValues(t, 10, len(searchRes1[0].IDs.(*column.ColumnInt64).Data()), "search on orig_vec should return 10 results")
if testData.SearchVec != nil {
searchRes2, err := mc.Search(ctx, client.NewSearchOption(collName, 10, []entity.Vector{testData.SearchVec}).WithANNSField("new_vec").WithOutputFields("new_vec"))
common.CheckErr(t, err, true)
require.EqualValues(t, 1, len(searchRes2))
searchIDs2 := searchRes2[0].IDs.(*column.ColumnInt64).Data()
require.EqualValues(t, 10, len(searchIDs2), "search on new_vec should return 10 results")
for _, id := range searchIDs2 {
require.True(t, id >= int64(nb), "search on new_vec should only return batch 2 rows, got pk %d", id)
_, ok := testData.PkToVecIdx[id-int64(nb)]
require.True(t, ok, "search result pk %d should have valid new_vec", id)
}
}
queryRes1, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 < 100").WithOutputFields("new_vec"))
common.CheckErr(t, err, true)
require.EqualValues(t, nb, queryRes1.ResultCount, "should have %d rows in batch 1", nb)
newVecCol1 := queryRes1.GetColumn("new_vec")
for i := 0; i < queryRes1.ResultCount; i++ {
isNull, err := newVecCol1.IsNull(i)
common.CheckErr(t, err, true)
require.True(t, isNull, "batch 1 rows should have null new_vec")
}
queryRes2, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 >= 100").WithOutputFields("new_vec"))
common.CheckErr(t, err, true)
require.EqualValues(t, nb, queryRes2.ResultCount, "should have %d rows in batch 2", nb)
pkToVecIdx2 := make(map[int64]int)
for pk, idx := range testData.PkToVecIdx {
// original PkToVecIdx uses pk 0..nb-1, need to map to nb..2*nb-1
pkToVecIdx2[pk+int64(nb)] = idx
}
VerifyNullableVectorDataWithFieldName(t, vt, queryRes2, pkToVecIdx2, testData.OriginalVectors, "new_vec", "query batch 2")
// clean up
err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName))
common.CheckErr(t, err, true)
})
}
}
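
// TestNullableVectorAddField rebases batch-2 primary keys inline when building
// pkToVecIdx2. As an illustrative sketch of that remapping (shiftPkToVecIdx is
// a hypothetical helper, not called by the tests above):
func shiftPkToVecIdx(pkToVecIdx map[int64]int, offset int64) map[int64]int {
	shifted := make(map[int64]int, len(pkToVecIdx))
	for pk, idx := range pkToVecIdx {
		// pk is the key used at insert time; offset moves it into the
		// collection-level pk range (e.g. batch 2 starts at pk=nb)
		shifted[pk+offset] = idx
	}
	return shifted
}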
func TestNullableVectorRangeSearch(t *testing.T) {
vectorTypes := GetVectorTypes()
for _, vt := range vectorTypes {
t.Run(vt.Name, func(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
// create collection
collName := common.GenRandomString("nullable_vec_range", 5)
pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
vecField := entity.NewField().WithName("vector").WithDataType(vt.FieldType).WithNullable(true)
if vt.FieldType != entity.FieldTypeSparseVector {
vecField = vecField.WithDim(common.DefaultDim)
}
schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField)
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
// generate data with 30% null
nb := 500
nullPercent := 30
testData := GenerateNullableVectorTestData(t, vt, nb, nullPercent, "vector")
// pk column
pkData := make([]int64, nb)
for i := range nb {
pkData[i] = int64(i)
}
pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData)
// insert
insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn, testData.VecColumn))
common.CheckErr(t, err, true)
require.EqualValues(t, nb, insertRes.InsertCount)
// flush
flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName))
common.CheckErr(t, err, true)
err = flushTask.Await(ctx)
common.CheckErr(t, err, true)
// create index with appropriate metric type
var vecIndex index.Index
switch vt.FieldType {
case entity.FieldTypeSparseVector:
vecIndex = index.NewSparseInvertedIndex(entity.IP, 0.1)
case entity.FieldTypeBinaryVector:
// BinaryVector uses Hamming distance
vecIndex = index.NewBinFlatIndex(entity.HAMMING)
case entity.FieldTypeInt8Vector:
// Int8Vector uses COSINE metric
vecIndex = index.NewHNSWIndex(entity.COSINE, 8, 96)
default:
// FloatVector, Float16Vector, BFloat16Vector use L2 metric
vecIndex = index.NewHNSWIndex(entity.L2, 8, 96)
}
indexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vector", vecIndex))
common.CheckErr(t, err, true)
err = indexTask.Await(ctx)
common.CheckErr(t, err, true)
// load
loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName))
common.CheckErr(t, err, true)
err = loadTask.Await(ctx)
common.CheckErr(t, err, true)
if testData.SearchVec != nil {
var searchRes []client.ResultSet
switch vt.FieldType {
case entity.FieldTypeSparseVector:
// For sparse vector, use IP metric with radius and range_filter
// IP metric: higher is better, range is [radius, range_filter]
annParams := index.NewSparseAnnParam()
annParams.WithRadius(0)
annParams.WithRangeFilter(100)
annParams.WithDropRatio(0.2)
searchRes, err = mc.Search(ctx, client.NewSearchOption(collName, 50, []entity.Vector{testData.SearchVec}).
WithANNSField("vector").WithAnnParam(annParams).WithOutputFields("vector"))
case entity.FieldTypeBinaryVector:
// For binary vector, use Hamming distance
// Hamming distance: smaller is better (number of different bits), range is [range_filter, radius]
// With dim=128, max Hamming distance is 128
searchRes, err = mc.Search(ctx, client.NewSearchOption(collName, 50, []entity.Vector{testData.SearchVec}).
WithANNSField("vector").WithSearchParam("radius", "128").WithSearchParam("range_filter", "0").WithOutputFields("vector"))
case entity.FieldTypeInt8Vector:
// For int8 vector, use COSINE metric
// COSINE distance: range is [0, 2], smaller is better, range is [range_filter, radius]
searchRes, err = mc.Search(ctx, client.NewSearchOption(collName, 50, []entity.Vector{testData.SearchVec}).
WithANNSField("vector").WithSearchParam("radius", "2").WithSearchParam("range_filter", "0").WithOutputFields("vector"))
default:
// For dense vectors (FloatVector, Float16Vector, BFloat16Vector), use L2 metric
// L2 distance: smaller is better, so radius is upper bound, range_filter is lower bound
searchRes, err = mc.Search(ctx, client.NewSearchOption(collName, 50, []entity.Vector{testData.SearchVec}).
WithANNSField("vector").WithSearchParam("radius", "100").WithSearchParam("range_filter", "0").WithOutputFields("vector"))
}
common.CheckErr(t, err, true)
require.EqualValues(t, 1, len(searchRes))
// Verify all results have valid vectors (not null)
searchIDs := searchRes[0].IDs.(*column.ColumnInt64).Data()
require.Greater(t, len(searchIDs), 0, "range search should return results")
for _, id := range searchIDs {
_, ok := testData.PkToVecIdx[id]
require.True(t, ok, "range search result pk %d should have valid vector", id)
}
// Verify scores are within range based on metric type
scores := searchRes[0].Scores
for i, score := range scores {
switch vt.FieldType {
case entity.FieldTypeSparseVector:
// IP metric: higher is better, range is [radius, range_filter] = [0, 100]
require.GreaterOrEqual(t, score, float32(0), "sparse vector score should be >= radius(0), got %f for pk %d", score, searchIDs[i])
require.LessOrEqual(t, score, float32(100), "sparse vector score should be <= range_filter(100), got %f for pk %d", score, searchIDs[i])
case entity.FieldTypeBinaryVector:
// Hamming distance: range is [range_filter, radius] = [0, 128]
require.GreaterOrEqual(t, score, float32(0), "Hamming score should be >= range_filter(0), got %f for pk %d", score, searchIDs[i])
require.LessOrEqual(t, score, float32(128), "Hamming score should be <= radius(128), got %f for pk %d", score, searchIDs[i])
case entity.FieldTypeInt8Vector:
// COSINE distance: range is [range_filter, radius] = [0, 2]
require.GreaterOrEqual(t, score, float32(0), "COSINE score should be >= range_filter(0), got %f for pk %d", score, searchIDs[i])
require.LessOrEqual(t, score, float32(2), "COSINE score should be <= radius(2), got %f for pk %d", score, searchIDs[i])
default:
// L2 metric: lower is better, range is [range_filter, radius] = [0, 100]
require.GreaterOrEqual(t, score, float32(0), "L2 score should be >= range_filter(0), got %f for pk %d", score, searchIDs[i])
require.LessOrEqual(t, score, float32(100), "L2 score should be <= radius(100), got %f for pk %d", score, searchIDs[i])
}
}
}
// clean up
err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName))
common.CheckErr(t, err, true)
})
}
}
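
// The score-range assertions above follow Milvus range-search semantics: for IP
// (used by sparse vectors, larger is better) valid scores lie in
// [radius, range_filter]; for Hamming, COSINE and L2 (smaller is better) they
// lie in [range_filter, radius]. A hypothetical sketch capturing the bounds
// used by TestNullableVectorRangeSearch (not referenced by the tests):
func rangeSearchBounds(ft entity.FieldType) (lo, hi float32) {
	switch ft {
	case entity.FieldTypeSparseVector:
		return 0, 100 // IP: radius=0, range_filter=100
	case entity.FieldTypeBinaryVector:
		return 0, 128 // Hamming: range_filter=0, radius=128 (dim=128)
	case entity.FieldTypeInt8Vector:
		return 0, 2 // COSINE distance: range_filter=0, radius=2
	default:
		return 0, 100 // L2 (FloatVector, Float16Vector, BFloat16Vector): range_filter=0, radius=100
	}
}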
// TestNullableVectorDifferentIndexTypes verifies index building on both SegmentGrowingImpl and ChunkedSegmentSealedImpl.
func TestNullableVectorDifferentIndexTypes(t *testing.T) {
vectorTypes := GetVectorTypes()
nullPercents := GetNullPercents()
segmentTypes := []string{"growing", "sealed"}
for _, vt := range vectorTypes {
indexConfigs := GetIndexesForVectorType(vt.FieldType)
for _, nullPercent := range nullPercents {
for _, segmentType := range segmentTypes {
// For growing segment, only test once with default index (interim index IVF_FLAT_CC is always used)
// For sealed segment, iterate through all user-specified index types
var testIndexConfigs []IndexConfig
if segmentType == "growing" {
// Only use first (default) index config for growing segment
testIndexConfigs = []IndexConfig{indexConfigs[0]}
} else {
// Test all index types for sealed segment
testIndexConfigs = indexConfigs
}
for _, idxCfg := range testIndexConfigs {
testName := fmt.Sprintf("%s_%s_%d%%null_%s", vt.Name, idxCfg.Name, nullPercent, segmentType)
idxCfgCopy := idxCfg // capture loop variable (redundant with Go 1.22+ per-iteration loop vars, kept for clarity)
t.Run(testName, func(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout*10)
mc := hp.CreateDefaultMilvusClient(ctx, t)
// Create collection with nullable vector
collName := common.GenRandomString("nullable_vec_large", 5)
pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
vecField := entity.NewField().WithName("vector").WithDataType(vt.FieldType).WithNullable(true)
if vt.FieldType != entity.FieldTypeSparseVector {
vecField = vecField.WithDim(common.DefaultDim)
}
schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField)
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
nb := 10000
validData := make([]bool, nb)
validCount := 0
for i := range nb {
validData[i] = (i % 100) >= nullPercent
if validData[i] {
validCount++
}
}
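// map each valid pk to its physical vector index; null rows are skipped,
// mirroring the sparse layout the server uses for nullable vector fields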
pkToVecIdx := make(map[int64]int)
vecIdx := 0
for i := range nb {
if validData[i] {
pkToVecIdx[int64(i)] = vecIdx
vecIdx++
}
}
// Generate pk column
pkData := make([]int64, nb)
for i := range nb {
pkData[i] = int64(i)
}
pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData)
// Generate vector column based on type
var vecColumn column.Column
var searchVec entity.Vector
var originalVectors interface{}
switch vt.FieldType {
case entity.FieldTypeFloatVector:
vectors := make([][]float32, validCount)
for i := range validCount {
vec := make([]float32, common.DefaultDim)
for j := range common.DefaultDim {
vec[j] = float32(i*common.DefaultDim+j) / float32(validCount*common.DefaultDim)
}
vectors[i] = vec
}
vecColumn, err = column.NewNullableColumnFloatVector("vector", common.DefaultDim, vectors, validData)
searchVec = entity.FloatVector(vectors[0])
originalVectors = vectors
case entity.FieldTypeBinaryVector:
vectors := make([][]byte, validCount)
byteDim := common.DefaultDim / 8
for i := range validCount {
vec := make([]byte, byteDim)
for j := range byteDim {
vec[j] = byte((i + j) % 256)
}
vectors[i] = vec
}
vecColumn, err = column.NewNullableColumnBinaryVector("vector", common.DefaultDim, vectors, validData)
searchVec = entity.BinaryVector(vectors[0])
originalVectors = vectors
case entity.FieldTypeFloat16Vector:
vectors := make([][]byte, validCount)
for i := range validCount {
vectors[i] = common.GenFloat16Vector(common.DefaultDim)
}
vecColumn, err = column.NewNullableColumnFloat16Vector("vector", common.DefaultDim, vectors, validData)
searchVec = entity.Float16Vector(vectors[0])
originalVectors = vectors
case entity.FieldTypeBFloat16Vector:
vectors := make([][]byte, validCount)
for i := range validCount {
vectors[i] = common.GenBFloat16Vector(common.DefaultDim)
}
vecColumn, err = column.NewNullableColumnBFloat16Vector("vector", common.DefaultDim, vectors, validData)
searchVec = entity.BFloat16Vector(vectors[0])
originalVectors = vectors
case entity.FieldTypeInt8Vector:
vectors := make([][]int8, validCount)
for i := range validCount {
vec := make([]int8, common.DefaultDim)
for j := range common.DefaultDim {
vec[j] = int8((i + j) % 127)
}
vectors[i] = vec
}
vecColumn, err = column.NewNullableColumnInt8Vector("vector", common.DefaultDim, vectors, validData)
searchVec = entity.Int8Vector(vectors[0])
originalVectors = vectors
case entity.FieldTypeSparseVector:
vectors := make([]entity.SparseEmbedding, validCount)
for i := range validCount {
positions := []uint32{0, uint32(i%1000 + 1), uint32(i%10000 + 1000)}
values := []float32{1.0, float32(i+1) / 1000.0, 0.1}
vectors[i], err = entity.NewSliceSparseEmbedding(positions, values)
common.CheckErr(t, err, true)
}
vecColumn, err = column.NewNullableColumnSparseFloatVector("vector", vectors, validData)
searchVec = vectors[0]
originalVectors = vectors
}
common.CheckErr(t, err, true)
// Insert data
insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn, vecColumn))
common.CheckErr(t, err, true)
require.EqualValues(t, nb, insertRes.InsertCount)
// For sealed segment, flush before creating index to convert growing to sealed
if segmentType == "sealed" {
flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName))
common.CheckErr(t, err, true)
err = flushTask.Await(ctx)
common.CheckErr(t, err, true)
}
// Create index using the config for this test iteration
vecIndex := CreateIndexFromConfig("vector", idxCfgCopy)
indexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vector", vecIndex))
common.CheckErr(t, err, true)
err = indexTask.Await(ctx)
common.CheckErr(t, err, true)
// Load collection with explicit load fields: when the vector field is indexed
// and listed in LoadFields, the system may serve it from the index instead of loading raw field data
loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName).
WithLoadFields(common.DefaultInt64FieldName, "vector")) // Load pk and vector (via index)
common.CheckErr(t, err, true)
err = loadTask.Await(ctx)
common.CheckErr(t, err, true)
// Search
searchRes, err := mc.Search(ctx, client.NewSearchOption(collName, 10, []entity.Vector{searchVec}).
WithOutputFields("*").
WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
require.EqualValues(t, 1, len(searchRes))
require.GreaterOrEqual(t, searchRes[0].ResultCount, 1)
// Verify search results
VerifyNullableVectorData(t, vt, searchRes[0], pkToVecIdx, originalVectors, "search")
// Query to count rows
queryRes, err := mc.Query(ctx, client.NewQueryOption(collName).
WithFilter(fmt.Sprintf("%s >= 0", common.DefaultInt64FieldName)).
WithOutputFields("count(*)"))
common.CheckErr(t, err, true)
countCol := queryRes.GetColumn("count(*)")
count, err := countCol.GetAsInt64(0)
common.CheckErr(t, err, true)
require.EqualValues(t, nb, count)
// Query with vector output to verify data
queryVecRes, err := mc.Query(ctx, client.NewQueryOption(collName).
WithFilter(fmt.Sprintf("%s < 100", common.DefaultInt64FieldName)).
WithOutputFields("*"))
common.CheckErr(t, err, true)
// Verify query results
VerifyNullableVectorData(t, vt, queryVecRes, pkToVecIdx, originalVectors, "query")
// Clean up
err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName))
common.CheckErr(t, err, true)
})
}
}
}
}
}
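
// Each test above builds the same logical-to-physical mapping inline. As an
// illustrative sketch (buildPkToVecIdx is hypothetical and unused by the tests),
// this is the client-side analogue of the server's logical->physical offset
// translation for nullable vectors: only valid rows consume a vector slot.
func buildPkToVecIdx(validData []bool) map[int64]int {
	pkToVecIdx := make(map[int64]int, len(validData))
	vecIdx := 0
	for i, valid := range validData {
		if valid {
			pkToVecIdx[int64(i)] = vecIdx
			vecIdx++
		}
	}
	return pkToVecIdx
}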
func TestNullableVectorGroupBy(t *testing.T) {
groupByVectorTypes := []NullableVectorType{
{"FloatVector", entity.FieldTypeFloatVector},
{"Float16Vector", entity.FieldTypeFloat16Vector},
{"BFloat16Vector", entity.FieldTypeBFloat16Vector},
}
nullPercents := GetNullPercents()
for _, vt := range groupByVectorTypes {
for _, nullPercent := range nullPercents {
testName := fmt.Sprintf("%s_%d%%null", vt.Name, nullPercent)
t.Run(testName, func(t *testing.T) {
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
mc := hp.CreateDefaultMilvusClient(ctx, t)
collName := common.GenRandomString("nullable_vec_groupby", 5)
pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
vecField := entity.NewField().WithName("vector").WithDataType(vt.FieldType).WithNullable(true).WithDim(common.DefaultDim)
groupField := entity.NewField().WithName("group_id").WithDataType(entity.FieldTypeInt64)
schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField).WithField(groupField)
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
nb := 500
numGroups := 50
rowsPerGroup := nb / numGroups
validData := make([]bool, nb)
validCount := 0
for i := range nb {
validData[i] = (i % 100) >= nullPercent
if validData[i] {
validCount++
}
}
pkData := make([]int64, nb)
groupData := make([]int64, nb)
for i := range nb {
pkData[i] = int64(i)
groupData[i] = int64(i / rowsPerGroup) // 0-9 -> group 0, 10-19 -> group 1, etc.
}
pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData)
groupColumn := column.NewColumnInt64("group_id", groupData)
var vecColumn column.Column
var searchVec entity.Vector
switch vt.FieldType {
case entity.FieldTypeFloatVector:
vectors := make([][]float32, validCount)
for i := range validCount {
vec := make([]float32, common.DefaultDim)
for j := range common.DefaultDim {
vec[j] = float32(i*common.DefaultDim+j) / 10000.0
}
vectors[i] = vec
}
vecColumn, err = column.NewNullableColumnFloatVector("vector", common.DefaultDim, vectors, validData)
if validCount > 0 {
searchVec = entity.FloatVector(vectors[0])
}
case entity.FieldTypeFloat16Vector:
vectors := make([][]byte, validCount)
for i := range validCount {
vectors[i] = common.GenFloat16Vector(common.DefaultDim)
}
vecColumn, err = column.NewNullableColumnFloat16Vector("vector", common.DefaultDim, vectors, validData)
if validCount > 0 {
searchVec = entity.Float16Vector(vectors[0])
}
case entity.FieldTypeBFloat16Vector:
vectors := make([][]byte, validCount)
for i := range validCount {
vectors[i] = common.GenBFloat16Vector(common.DefaultDim)
}
vecColumn, err = column.NewNullableColumnBFloat16Vector("vector", common.DefaultDim, vectors, validData)
if validCount > 0 {
searchVec = entity.BFloat16Vector(vectors[0])
}
}
common.CheckErr(t, err, true)
// Insert
insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn, vecColumn, groupColumn))
common.CheckErr(t, err, true)
require.EqualValues(t, nb, insertRes.InsertCount)
// Flush
flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName))
common.CheckErr(t, err, true)
err = flushTask.Await(ctx)
common.CheckErr(t, err, true)
if validCount > 0 {
vecIndex := index.NewFlatIndex(entity.L2)
indexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vector", vecIndex))
common.CheckErr(t, err, true)
err = indexTask.Await(ctx)
common.CheckErr(t, err, true)
scalarIndexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "group_id", index.NewAutoIndex(entity.L2)))
common.CheckErr(t, err, true)
err = scalarIndexTask.Await(ctx)
common.CheckErr(t, err, true)
// Load
loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName))
common.CheckErr(t, err, true)
err = loadTask.Await(ctx)
common.CheckErr(t, err, true)
searchRes, err := mc.Search(ctx, client.NewSearchOption(collName, 10, []entity.Vector{searchVec}).
WithANNSField("vector").
WithGroupByField("group_id").
WithOutputFields(common.DefaultInt64FieldName, "group_id"))
common.CheckErr(t, err, true)
require.EqualValues(t, 1, len(searchRes))
// Verify GroupBy semantics:
// 1. Result count should be <= limit (number of unique groups)
// 2. Each result should have a unique group_id
// 3. All returned PKs should have valid vectors (not null)
resultCount := searchRes[0].ResultCount
require.LessOrEqual(t, resultCount, 10, "result count should be <= limit")
// Check unique group_ids
seenGroups := make(map[int64]bool)
for i := 0; i < resultCount; i++ {
groupByValue, err := searchRes[0].GroupByValue.Get(i)
require.NoError(t, err)
groupID := groupByValue.(int64)
require.False(t, seenGroups[groupID], "group_id should be unique in GroupBy results")
seenGroups[groupID] = true
// Verify the returned PK has a valid vector
pkValue, _ := searchRes[0].IDs.GetAsInt64(i)
require.True(t, validData[pkValue], "returned pk %d should have valid vector", pkValue)
}
}
// Clean up
err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName))
common.CheckErr(t, err, true)
})
}
}
}
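
// The deterministic null layout shared by these tests can be summarized in one
// place. A hypothetical sketch (genValidityPattern is not referenced by the
// tests above): row i is null iff (i % 100) < nullPercent, so every full block
// of 100 rows contains exactly nullPercent nulls.
func genValidityPattern(nb, nullPercent int) (validData []bool, validCount int) {
	validData = make([]bool, nb)
	for i := range nb {
		validData[i] = (i % 100) >= nullPercent
		if validData[i] {
			validCount++
		}
	}
	return validData, validCount
}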