fix: Prevent panic in upsert due to missing nullable fields [Proxy] (#44070)
issue: #43980

Fixes a panic that occurred when a partial update was converted to an insert because the primary key did not exist. The panic was caused by nullable fields that the original partial-update request did not provide. The upsert pre-execution logic is refactored to handle this correctly:

- Explicitly splits upsert data into 'insert' and 'update' batches.
- Automatically generates data for missing nullable or default-value fields during inserts, preventing the panic.
- Enhances `typeutil.UpdateFieldData` to support different source and destination indexes for flexible data merging.
- Adds comprehensive unit tests for mixed upsert, pure insert, and pure update scenarios.

---------

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
This commit is contained in: parent 4376876f90, commit 16af4e230a
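Before the diff, here is a minimal, self-contained sketch of the new pre-execution flow: split the upsert rows into update and insert batches by probing which primary keys already exist, then backfill nullable fields that the request omitted so the insert batch can no longer panic on a missing column. All names and types below are illustrative stand-ins for this note, not the actual Milvus proxy APIs, which operate on columnar schemapb.FieldData instead of rows.

package main

import "fmt"

// Row is a simplified stand-in for one upsert row.
type Row struct {
	PK     int64
	Fields map[string]any // only the fields the client actually provided
}

// splitUpsert partitions upsert rows into update and insert batches,
// mirroring step 1 of the refactored queryPreExecute below.
func splitUpsert(rows []Row, existing map[int64]bool) (updates, inserts []Row) {
	for _, r := range rows {
		if existing[r.PK] {
			updates = append(updates, r)
		} else {
			inserts = append(inserts, r)
		}
	}
	return
}

// fillNullable generates placeholder (null) values for nullable fields the
// request omitted, so pure inserts no longer fail on missing columns.
func fillNullable(rows []Row, nullableFields []string) {
	for i := range rows {
		for _, f := range nullableFields {
			if _, ok := rows[i].Fields[f]; !ok {
				rows[i].Fields[f] = nil // the real code marks these rows invalid via ValidData
			}
		}
	}
}

func main() {
	rows := []Row{
		{PK: 1, Fields: map[string]any{"value": 100}},
		{PK: 3, Fields: map[string]any{"value": 300}},
	}
	existing := map[int64]bool{1: true} // PK 1 exists, PK 3 does not
	updates, inserts := splitUpsert(rows, existing)
	fillNullable(inserts, []string{"extra"})
	fmt.Println(len(updates), len(inserts)) // 1 1
}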
@@ -21,6 +21,7 @@ import (
 	"strconv"

 	"github.com/cockroachdb/errors"
+	"github.com/samber/lo"
 	"go.opentelemetry.io/otel"
 	"go.uber.org/zap"
@@ -244,14 +245,14 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
 		return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("must assign pk when upsert, primary field: %v", primaryFieldSchema.Name))
 	}

-	oldIDs, err := parsePrimaryFieldData2IDs(primaryFieldData)
+	upsertIDs, err := parsePrimaryFieldData2IDs(primaryFieldData)
 	if err != nil {
 		log.Warn("parse primary field data to IDs failed", zap.Error(err))
 		return err
 	}

-	oldIDSize := typeutil.GetSizeOfIDs(oldIDs)
-	if oldIDSize == 0 {
+	upsertIDSize := typeutil.GetSizeOfIDs(upsertIDs)
+	if upsertIDSize == 0 {
 		it.deletePKs = &schemapb.IDs{}
 		it.insertFieldData = it.req.GetFieldsData()
 		log.Info("old records not found, just do insert")
@@ -260,7 +261,7 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {

 	tr := timerecord.NewTimeRecorder("Proxy-Upsert-retrieveByPKs")
 	// retrieve by primary key to get original field data
-	resp, err := retrieveByPKs(ctx, it, oldIDs, []string{"*"})
+	resp, err := retrieveByPKs(ctx, it, upsertIDs, []string{"*"})
 	if err != nil {
 		log.Info("retrieve by primary key failed", zap.Error(err))
 		return err
@@ -285,23 +286,7 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
 		zap.Int("resultNum", typeutil.GetSizeOfIDs(existIDs)),
 		zap.Int64("latency", tr.ElapseSpan().Milliseconds()))

-	// check whether the primary key is exist in query result
-	idsChecker, err := typeutil.NewIDsChecker(existIDs)
-	if err != nil {
-		log.Info("create primary key checker failed", zap.Error(err))
-		return err
-	}
-
-	// Build mapping from existing primary keys to their positions in query result
-	// This ensures we can correctly locate data even if query results are not in the same order as request
-	existIDsLen := typeutil.GetSizeOfIDs(existIDs)
-	existPKToIndex := make(map[interface{}]int, existIDsLen)
-	for j := 0; j < existIDsLen; j++ {
-		pk := typeutil.GetPK(existIDs, int64(j))
-		existPKToIndex[pk] = j
-	}
-
-	// set field id for user passed field data
+	// set field id for user passed field data, prepare for merge logic
 	upsertFieldData := it.upsertMsg.InsertMsg.GetFieldsData()
 	if len(upsertFieldData) == 0 {
 		return merr.WrapErrParameterInvalidMsg("upsert field data is empty")
@@ -320,72 +305,121 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
 		fieldData.FieldName = fieldName
 	}

-	lackOfFieldErr := LackOfFieldsDataBySchema(it.schema.CollectionSchema, it.upsertMsg.InsertMsg.GetFieldsData(), false, true)
-	it.deletePKs = &schemapb.IDs{}
-	it.insertFieldData = make([]*schemapb.FieldData, len(existFieldData))
-	for i := 0; i < oldIDSize; i++ {
-		exist, err := idsChecker.Contains(oldIDs, i)
+	// Note: the most difficult part is to handle the merge progress of upsert and query result
+	// we need to enable merge logic on different length between upsertFieldData and it.insertFieldData
+	insertIdxInUpsert := make([]int, 0)
+	updateIdxInUpsert := make([]int, 0)
+	// 1. split upsert data into insert and update by query result
+	idsChecker, err := typeutil.NewIDsChecker(existIDs)
+	if err != nil {
+		log.Info("create primary key checker failed", zap.Error(err))
+		return err
+	}
+	for upsertIdx := 0; upsertIdx < upsertIDSize; upsertIdx++ {
+		exist, err := idsChecker.Contains(upsertIDs, upsertIdx)
 		if err != nil {
 			log.Info("check primary key exist in query result failed", zap.Error(err))
 			return err
 		}

 		if exist {
-			// treat upsert as update
-			// 1. if pk exist in query result, add it to deletePKs
-			typeutil.AppendIDs(it.deletePKs, oldIDs, i)
-			// 2. construct the field data for update using correct index mapping
-			oldPK := typeutil.GetPK(oldIDs, int64(i))
+			updateIdxInUpsert = append(updateIdxInUpsert, upsertIdx)
 		} else {
+			insertIdxInUpsert = append(insertIdxInUpsert, upsertIdx)
+		}
+	}
+
+	// 2. merge field data on update semantic
+	it.deletePKs = &schemapb.IDs{}
+	it.insertFieldData = typeutil.PrepareResultFieldData(existFieldData, int64(upsertIDSize))
+	if len(updateIdxInUpsert) > 0 {
+		// Note: For fields containing default values, default values need to be set according to valid data during insertion,
+		// but query results fields do not set valid data when returning default value fields,
+		// therefore valid data needs to be manually set to true
+		for _, fieldData := range existFieldData {
+			fieldSchema, err := it.schema.schemaHelper.GetFieldFromName(fieldData.GetFieldName())
+			if err != nil {
+				log.Info("get field schema failed", zap.Error(err))
+				return err
+			}
+
+			if fieldSchema.GetDefaultValue() != nil {
+				fieldData.ValidData = make([]bool, upsertIDSize)
+				for i := range fieldData.ValidData {
+					fieldData.ValidData[i] = true
+				}
+			}
+		}
+
+		// Build mapping from existing primary keys to their positions in query result
+		// This ensures we can correctly locate data even if query results are not in the same order as request
+		existIDsLen := typeutil.GetSizeOfIDs(existIDs)
+		existPKToIndex := make(map[interface{}]int, existIDsLen)
+		for j := 0; j < existIDsLen; j++ {
+			pk := typeutil.GetPK(existIDs, int64(j))
+			existPKToIndex[pk] = j
+		}
+
+		baseIdx := 0
+		for _, idx := range updateIdxInUpsert {
+			typeutil.AppendIDs(it.deletePKs, upsertIDs, idx)
+			oldPK := typeutil.GetPK(upsertIDs, int64(idx))
 			existIndex, ok := existPKToIndex[oldPK]
 			if !ok {
 				return merr.WrapErrParameterInvalidMsg("primary key not found in exist data mapping")
 			}
 			typeutil.AppendFieldData(it.insertFieldData, existFieldData, int64(existIndex))
-			err := typeutil.UpdateFieldData(it.insertFieldData, upsertFieldData, int64(i))
+			err := typeutil.UpdateFieldData(it.insertFieldData, upsertFieldData, int64(baseIdx), int64(idx))
+			baseIdx += 1
 			if err != nil {
 				log.Info("update field data failed", zap.Error(err))
 				return err
 			}
-		} else {
-			// treat upsert as insert
-			if lackOfFieldErr != nil {
-				log.Info("check fields data by schema failed", zap.Error(lackOfFieldErr))
-				return lackOfFieldErr
-			}
-			// use field data from upsert request
-			typeutil.AppendFieldData(it.insertFieldData, upsertFieldData, int64(i))
-		}
 	}

-	for _, fieldData := range it.insertFieldData {
-		if fieldData.GetIsDynamic() {
-			continue
-		}
-		fieldSchema, err := it.schema.schemaHelper.GetFieldFromName(fieldData.GetFieldName())
-		if err != nil {
-			log.Info("get field schema failed", zap.Error(err))
-			return err
+	// 3. merge field data on insert semantic
+	if len(insertIdxInUpsert) > 0 {
+		// if necessary field is not exist in upsert request, return error
+		lackOfFieldErr := LackOfFieldsDataBySchema(it.schema.CollectionSchema, it.upsertMsg.InsertMsg.GetFieldsData(), false, true)
+		if lackOfFieldErr != nil {
+			log.Info("check fields data by schema failed", zap.Error(lackOfFieldErr))
+			return lackOfFieldErr
 		}

-		// Note: Since protobuf cannot correctly identify null values, zero values + valid data are used to identify null values,
-		// therefore for field data obtained from query results, if the field is nullable, it needs to be set to empty values
-		if fieldSchema.GetNullable() {
-			if getValidNumber(fieldData.GetValidData()) != len(fieldData.GetValidData()) {
-				err := ResetNullFieldData(fieldData, fieldSchema)
-				if err != nil {
-					log.Info("reset null field data failed", zap.Error(err))
-					return err
+		// if the nullable or default value field is not set in upsert request, which means the len(upsertFieldData) < len(it.insertFieldData)
+		// we need to generate the nullable field data before append as insert
+		insertWithNullField := make([]*schemapb.FieldData, 0)
+		upsertFieldMap := lo.SliceToMap(it.upsertMsg.InsertMsg.GetFieldsData(), func(field *schemapb.FieldData) (string, *schemapb.FieldData) {
+			return field.GetFieldName(), field
+		})
+		for _, fieldSchema := range it.schema.CollectionSchema.Fields {
+			if fieldData, ok := upsertFieldMap[fieldSchema.Name]; !ok {
+				if fieldSchema.GetNullable() || fieldSchema.GetDefaultValue() != nil {
+					fieldData, err := GenNullableFieldData(fieldSchema, upsertIDSize)
+					if err != nil {
+						log.Info("generate nullable field data failed", zap.Error(err))
+						return err
+					}
+					insertWithNullField = append(insertWithNullField, fieldData)
 				}
+			} else {
+				insertWithNullField = append(insertWithNullField, fieldData)
 			}
 		}
+		for _, idx := range insertIdxInUpsert {
+			typeutil.AppendFieldData(it.insertFieldData, insertWithNullField, int64(idx))
+		}
+	}

-		// Note: For fields containing default values, default values need to be set according to valid data during insertion,
-		// but query results fields do not set valid data when returning default value fields,
-		// therefore valid data needs to be manually set to true
-		if fieldSchema.GetDefaultValue() != nil {
-			fieldData.ValidData = make([]bool, oldIDSize)
-			for i := range fieldData.ValidData {
-				fieldData.ValidData[i] = true
+	// 4. clean field data with valid data after merge upsert and query result
+	for _, fieldData := range it.insertFieldData {
+		// Note: Since protobuf cannot correctly identify null values, zero values + valid data are used to identify null values,
+		// therefore for field data obtained from query results, if the field is nullable, it needs to clean zero values
+		if len(fieldData.GetValidData()) != 0 && getValidNumber(fieldData.GetValidData()) != len(fieldData.GetValidData()) {
+			err := ResetNullFieldData(fieldData)
+			if err != nil {
+				log.Info("reset null field data failed", zap.Error(err))
+				return err
 			}
 		}
 	}
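The update branch above implements upsert as delete-plus-reinsert: each existing PK goes into deletePKs, and the re-inserted row is the queried old row overlaid with only the fields the client supplied. A rough sketch of that per-row merge, using plain maps rather than the columnar schemapb types (illustrative only, not the actual proxy code):

package main

import "fmt"

// mergeRow overlays client-supplied fields onto the row fetched by primary
// key, which is what AppendFieldData + UpdateFieldData accomplish columnarly.
func mergeRow(old, update map[string]any) map[string]any {
	merged := make(map[string]any, len(old))
	for k, v := range old {
		merged[k] = v // start from the existing row in the query result
	}
	for k, v := range update {
		merged[k] = v // overwrite only the fields the client provided
	}
	return merged
}

func main() {
	old := map[string]any{"value": 10, "extra": "old1"}
	update := map[string]any{"value": 100} // partial update: "extra" untouched
	fmt.Println(mergeRow(old, update))     // map[extra:old1 value:100]
}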
@@ -393,11 +427,7 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
 	return nil
 }

-func ResetNullFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
-	if !fieldSchema.GetNullable() {
-		return nil
-	}
-
+func ResetNullFieldData(field *schemapb.FieldData) error {
 	switch field.Field.(type) {
 	case *schemapb.FieldData_Scalars:
 		switch sd := field.GetScalars().GetData().(type) {
@@ -406,7 +436,7 @@ func ResetNullFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
 			if validRowNum == 0 {
 				sd.BoolData.Data = make([]bool, 0)
 			} else {
-				ret := make([]bool, validRowNum)
+				ret := make([]bool, 0, validRowNum)
 				for i, valid := range field.GetValidData() {
 					if valid {
 						ret = append(ret, sd.BoolData.Data[i])
@@ -420,7 +450,7 @@ func ResetNullFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
 			if validRowNum == 0 {
 				sd.IntData.Data = make([]int32, 0)
 			} else {
-				ret := make([]int32, validRowNum)
+				ret := make([]int32, 0, validRowNum)
 				for i, valid := range field.GetValidData() {
 					if valid {
 						ret = append(ret, sd.IntData.Data[i])
@@ -434,7 +464,7 @@ func ResetNullFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
 			if validRowNum == 0 {
 				sd.LongData.Data = make([]int64, 0)
 			} else {
-				ret := make([]int64, validRowNum)
+				ret := make([]int64, 0, validRowNum)
 				for i, valid := range field.GetValidData() {
 					if valid {
 						ret = append(ret, sd.LongData.Data[i])
@@ -448,7 +478,7 @@ func ResetNullFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
 			if validRowNum == 0 {
 				sd.FloatData.Data = make([]float32, 0)
 			} else {
-				ret := make([]float32, validRowNum)
+				ret := make([]float32, 0, validRowNum)
 				for i, valid := range field.GetValidData() {
 					if valid {
 						ret = append(ret, sd.FloatData.Data[i])
@@ -462,7 +492,7 @@ func ResetNullFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
 			if validRowNum == 0 {
 				sd.DoubleData.Data = make([]float64, 0)
 			} else {
-				ret := make([]float64, validRowNum)
+				ret := make([]float64, 0, validRowNum)
 				for i, valid := range field.GetValidData() {
 					if valid {
 						ret = append(ret, sd.DoubleData.Data[i])
@@ -476,7 +506,7 @@ func ResetNullFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
 			if validRowNum == 0 {
 				sd.StringData.Data = make([]string, 0)
 			} else {
-				ret := make([]string, validRowNum)
+				ret := make([]string, 0, validRowNum)
 				for i, valid := range field.GetValidData() {
 					if valid {
 						ret = append(ret, sd.StringData.Data[i])
@@ -490,7 +520,7 @@ func ResetNullFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
 			if validRowNum == 0 {
 				sd.JsonData.Data = make([][]byte, 0)
 			} else {
-				ret := make([][]byte, validRowNum)
+				ret := make([][]byte, 0, validRowNum)
 				for i, valid := range field.GetValidData() {
 					if valid {
 						ret = append(ret, sd.JsonData.Data[i])
@@ -510,6 +540,139 @@ func ResetNullFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
 	return nil
 }

+func GenNullableFieldData(field *schemapb.FieldSchema, upsertIDSize int) (*schemapb.FieldData, error) {
+	switch field.DataType {
+	case schemapb.DataType_Bool:
+		return &schemapb.FieldData{
+			FieldId:   field.FieldID,
+			FieldName: field.Name,
+			Type:      field.DataType,
+			IsDynamic: field.IsDynamic,
+			ValidData: make([]bool, upsertIDSize),
+			Field: &schemapb.FieldData_Scalars{
+				Scalars: &schemapb.ScalarField{
+					Data: &schemapb.ScalarField_BoolData{
+						BoolData: &schemapb.BoolArray{
+							Data: make([]bool, upsertIDSize),
+						},
+					},
+				},
+			},
+		}, nil
+
+	case schemapb.DataType_Int32:
+		return &schemapb.FieldData{
+			FieldId:   field.FieldID,
+			FieldName: field.Name,
+			Type:      field.DataType,
+			IsDynamic: field.IsDynamic,
+			ValidData: make([]bool, upsertIDSize),
+			Field: &schemapb.FieldData_Scalars{
+				Scalars: &schemapb.ScalarField{
+					Data: &schemapb.ScalarField_IntData{
+						IntData: &schemapb.IntArray{
+							Data: make([]int32, upsertIDSize),
+						},
+					},
+				},
+			},
+		}, nil
+
+	case schemapb.DataType_Int64:
+		return &schemapb.FieldData{
+			FieldId:   field.FieldID,
+			FieldName: field.Name,
+			Type:      field.DataType,
+			IsDynamic: field.IsDynamic,
+			ValidData: make([]bool, upsertIDSize),
+			Field: &schemapb.FieldData_Scalars{
+				Scalars: &schemapb.ScalarField{
+					Data: &schemapb.ScalarField_LongData{
+						LongData: &schemapb.LongArray{
+							Data: make([]int64, upsertIDSize),
+						},
+					},
+				},
+			},
+		}, nil
+
+	case schemapb.DataType_Float:
+		return &schemapb.FieldData{
+			FieldId:   field.FieldID,
+			FieldName: field.Name,
+			Type:      field.DataType,
+			IsDynamic: field.IsDynamic,
+			ValidData: make([]bool, upsertIDSize),
+			Field: &schemapb.FieldData_Scalars{
+				Scalars: &schemapb.ScalarField{
+					Data: &schemapb.ScalarField_FloatData{
+						FloatData: &schemapb.FloatArray{
+							Data: make([]float32, upsertIDSize),
+						},
+					},
+				},
+			},
+		}, nil
+
+	case schemapb.DataType_Double:
+		return &schemapb.FieldData{
+			FieldId:   field.FieldID,
+			FieldName: field.Name,
+			Type:      field.DataType,
+			IsDynamic: field.IsDynamic,
+			ValidData: make([]bool, upsertIDSize),
+			Field: &schemapb.FieldData_Scalars{
+				Scalars: &schemapb.ScalarField{
+					Data: &schemapb.ScalarField_DoubleData{
+						DoubleData: &schemapb.DoubleArray{
+							Data: make([]float64, upsertIDSize),
+						},
+					},
+				},
+			},
+		}, nil
+
+	case schemapb.DataType_VarChar:
+		return &schemapb.FieldData{
+			FieldId:   field.FieldID,
+			FieldName: field.Name,
+			Type:      field.DataType,
+			IsDynamic: field.IsDynamic,
+			ValidData: make([]bool, upsertIDSize),
+			Field: &schemapb.FieldData_Scalars{
+				Scalars: &schemapb.ScalarField{
+					Data: &schemapb.ScalarField_StringData{
+						StringData: &schemapb.StringArray{
+							Data: make([]string, upsertIDSize),
+						},
+					},
+				},
+			},
+		}, nil
+
+	case schemapb.DataType_JSON:
+		return &schemapb.FieldData{
+			FieldId:   field.FieldID,
+			FieldName: field.Name,
+			Type:      field.DataType,
+			IsDynamic: field.IsDynamic,
+			ValidData: make([]bool, upsertIDSize),
+			Field: &schemapb.FieldData_Scalars{
+				Scalars: &schemapb.ScalarField{
+					Data: &schemapb.ScalarField_JsonData{
+						JsonData: &schemapb.JSONArray{
+							Data: make([][]byte, upsertIDSize),
+						},
+					},
+				},
+			},
+		}, nil
+
+	default:
+		return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("undefined scalar data type:%s", field.DataType.String()))
+	}
+}
+
 func (it *upsertTask) insertPreExecute(ctx context.Context) error {
 	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-insertPreExecute")
 	defer sp.End()
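Note the shape GenNullableFieldData produces: a zero-filled backing array plus an all-false ValidData mask (Go's zero value for bool), i.e. every generated row is null until later logic marks it valid. A minimal stand-alone illustration of that shape for an Int64 column, using plain slices rather than the schemapb wrappers (illustrative stand-ins only):

package main

import "fmt"

// nullColumn mimics the shape GenNullableFieldData produces for an Int64
// field: a zero-filled data array plus an all-false validity mask.
type nullColumn struct {
	Data      []int64
	ValidData []bool
}

func genNullColumn(rows int) nullColumn {
	return nullColumn{
		Data:      make([]int64, rows), // zeros act as placeholders
		ValidData: make([]bool, rows),  // all false => every row is null
	}
}

func main() {
	col := genNullColumn(3)
	fmt.Println(col.Data, col.ValidData) // [0 0 0] [false false false]
}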
@@ -1082,3 +1082,273 @@ func TestUpdateTask_PreExecute_QueryPreExecuteError(t *testing.T) {
 		assert.Contains(t, err.Error(), "query pre-execute failed")
 	})
 }
+
+func TestUpsertTask_queryPreExecute_MixLogic(t *testing.T) {
+	// Schema for the test collection
+	schema := newSchemaInfo(&schemapb.CollectionSchema{
+		Name: "test_merge_collection",
+		Fields: []*schemapb.FieldSchema{
+			{FieldID: 100, Name: "id", IsPrimaryKey: true, DataType: schemapb.DataType_Int64},
+			{FieldID: 101, Name: "value", DataType: schemapb.DataType_Int32},
+			{FieldID: 102, Name: "extra", DataType: schemapb.DataType_VarChar, Nullable: true},
+		},
+	})
+
+	// Upsert IDs: 1 (update), 2 (update), 3 (insert)
+	upsertData := []*schemapb.FieldData{
+		{
+			FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
+			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}}}}},
+		},
+		{
+			FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
+			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{100, 200, 300}}}}},
+		},
+	}
+	numRows := uint64(len(upsertData[0].GetScalars().GetLongData().GetData()))
+
+	// Query result for existing PKs: 1, 2
+	mockQueryResult := &milvuspb.QueryResults{
+		Status: merr.Success(),
+		FieldsData: []*schemapb.FieldData{
+			{
+				FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
+				Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{1, 2}}}}},
+			},
+			{
+				FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
+				Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{10, 20}}}}},
+			},
+			{
+				FieldName: "extra", FieldId: 102, Type: schemapb.DataType_VarChar,
+				Field:     &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{Data: []string{"old1", "old2"}}}}},
+				ValidData: []bool{true, true},
+			},
+		},
+	}
+
+	task := &upsertTask{
+		ctx:    context.Background(),
+		schema: schema,
+		req: &milvuspb.UpsertRequest{
+			FieldsData: upsertData,
+			NumRows:    uint32(numRows),
+		},
+		upsertMsg: &msgstream.UpsertMsg{
+			InsertMsg: &msgstream.InsertMsg{
+				InsertRequest: &msgpb.InsertRequest{
+					FieldsData: upsertData,
+					NumRows:    numRows,
+				},
+			},
+		},
+		node: &Proxy{},
+	}
+
+	mockRetrieve := mockey.Mock(retrieveByPKs).Return(mockQueryResult, nil).Build()
+	defer mockRetrieve.UnPatch()
+
+	err := task.queryPreExecute(context.Background())
+	assert.NoError(t, err)
+
+	// Verify delete PKs
+	deletePks := task.deletePKs.GetIntId().GetData()
+	assert.ElementsMatch(t, []int64{1, 2}, deletePks)
+
+	// Verify merged insert data
+	primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema.CollectionSchema)
+	assert.NoError(t, err)
+	idField, err := typeutil.GetPrimaryFieldData(task.insertFieldData, primaryFieldSchema)
+	assert.NoError(t, err)
+	ids, err := parsePrimaryFieldData2IDs(idField)
+	assert.NoError(t, err)
+	insertPKs := ids.GetIntId().GetData()
+	assert.Equal(t, []int64{1, 2, 3}, insertPKs)
+
+	var valueField *schemapb.FieldData
+	for _, f := range task.insertFieldData {
+		if f.GetFieldName() == "value" {
+			valueField = f
+			break
+		}
+	}
+	assert.NotNil(t, valueField)
+	assert.Equal(t, []int32{100, 200, 300}, valueField.GetScalars().GetIntData().GetData())
+}
+
+func TestUpsertTask_queryPreExecute_PureInsert(t *testing.T) {
+	// Schema for the test collection
+	schema := newSchemaInfo(&schemapb.CollectionSchema{
+		Name: "test_merge_collection",
+		Fields: []*schemapb.FieldSchema{
+			{FieldID: 100, Name: "id", IsPrimaryKey: true, DataType: schemapb.DataType_Int64},
+			{FieldID: 101, Name: "value", DataType: schemapb.DataType_Int32},
+			{FieldID: 102, Name: "extra", DataType: schemapb.DataType_VarChar, Nullable: true},
+		},
+	})
+
+	// Upsert IDs: 4, 5
+	upsertData := []*schemapb.FieldData{
+		{
+			FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
+			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{4, 5}}}}},
+		},
+		{
+			FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
+			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{400, 500}}}}},
+		},
+	}
+	numRows := uint64(len(upsertData[0].GetScalars().GetLongData().GetData()))
+
+	// Query result is empty, but schema is preserved
+	mockQueryResult := &milvuspb.QueryResults{Status: merr.Success(), FieldsData: []*schemapb.FieldData{
+		{
+			FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
+			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{}}}}},
+		},
+		{
+			FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
+			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{}}}}},
+		},
+		{
+			FieldName: "extra", FieldId: 102, Type: schemapb.DataType_VarChar,
+			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{Data: []string{}}}}},
+		},
+	}}
+
+	task := &upsertTask{
+		ctx:    context.Background(),
+		schema: schema,
+		req: &milvuspb.UpsertRequest{
+			FieldsData: upsertData,
+			NumRows:    uint32(numRows),
+		},
+		upsertMsg: &msgstream.UpsertMsg{
+			InsertMsg: &msgstream.InsertMsg{
+				InsertRequest: &msgpb.InsertRequest{
+					FieldsData: upsertData,
+					NumRows:    numRows,
+				},
+			},
+		},
+		node: &Proxy{},
+	}
+
+	mockRetrieve := mockey.Mock(retrieveByPKs).Return(mockQueryResult, nil).Build()
+	defer mockRetrieve.UnPatch()
+
+	err := task.queryPreExecute(context.Background())
+	assert.NoError(t, err)
+
+	// Verify delete PKs
+	deletePks := task.deletePKs.GetIntId().GetData()
+	assert.Empty(t, deletePks)
+
+	// Verify merged insert data
+	primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema.CollectionSchema)
+	assert.NoError(t, err)
+	idField, err := typeutil.GetPrimaryFieldData(task.insertFieldData, primaryFieldSchema)
+	assert.NoError(t, err)
+	ids, err := parsePrimaryFieldData2IDs(idField)
+	assert.NoError(t, err)
+	insertPKs := ids.GetIntId().GetData()
+	assert.Equal(t, []int64{4, 5}, insertPKs)
+
+	var valueField *schemapb.FieldData
+	for _, f := range task.insertFieldData {
+		if f.GetFieldName() == "value" {
+			valueField = f
+			break
+		}
+	}
+	assert.NotNil(t, valueField)
+	assert.Equal(t, []int32{400, 500}, valueField.GetScalars().GetIntData().GetData())
+}
+
+func TestUpsertTask_queryPreExecute_PureUpdate(t *testing.T) {
+	// Schema for the test collection
+	schema := newSchemaInfo(&schemapb.CollectionSchema{
+		Name: "test_merge_collection",
+		Fields: []*schemapb.FieldSchema{
+			{FieldID: 100, Name: "id", IsPrimaryKey: true, DataType: schemapb.DataType_Int64},
+			{FieldID: 101, Name: "value", DataType: schemapb.DataType_Int32},
+			{FieldID: 102, Name: "extra", DataType: schemapb.DataType_VarChar, Nullable: true},
+		},
+	})
+
+	// Upsert IDs: 6, 7
+	upsertData := []*schemapb.FieldData{
+		{
+			FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
+			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{6, 7}}}}},
+		},
+		{
+			FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
+			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{600, 700}}}}},
+		},
+	}
+	numRows := uint64(len(upsertData[0].GetScalars().GetLongData().GetData()))
+
+	// Query result for existing PKs: 6, 7
+	mockQueryResult := &milvuspb.QueryResults{
+		Status: merr.Success(),
+		FieldsData: []*schemapb.FieldData{
+			{
+				FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
+				Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{6, 7}}}}},
+			},
+			{
+				FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
+				Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{60, 70}}}}},
+			},
+		},
+	}
+
+	task := &upsertTask{
+		ctx:    context.Background(),
+		schema: schema,
+		req: &milvuspb.UpsertRequest{
+			FieldsData: upsertData,
+			NumRows:    uint32(numRows),
+		},
+		upsertMsg: &msgstream.UpsertMsg{
+			InsertMsg: &msgstream.InsertMsg{
+				InsertRequest: &msgpb.InsertRequest{
+					FieldsData: upsertData,
+					NumRows:    numRows,
+				},
+			},
+		},
+		node: &Proxy{},
+	}
+
+	mockRetrieve := mockey.Mock(retrieveByPKs).Return(mockQueryResult, nil).Build()
+	defer mockRetrieve.UnPatch()
+
+	err := task.queryPreExecute(context.Background())
+	assert.NoError(t, err)
+
+	// Verify delete PKs
+	deletePks := task.deletePKs.GetIntId().GetData()
+	assert.ElementsMatch(t, []int64{6, 7}, deletePks)
+
+	// Verify merged insert data
+	primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema.CollectionSchema)
+	assert.NoError(t, err)
+	idField, err := typeutil.GetPrimaryFieldData(task.insertFieldData, primaryFieldSchema)
+	assert.NoError(t, err)
+	ids, err := parsePrimaryFieldData2IDs(idField)
+	assert.NoError(t, err)
+	insertPKs := ids.GetIntId().GetData()
+	assert.Equal(t, []int64{6, 7}, insertPKs)
+
+	var valueField *schemapb.FieldData
+	for _, f := range task.insertFieldData {
+		if f.GetFieldName() == "value" {
+			valueField = f
+			break
+		}
+	}
+	assert.NotNil(t, valueField)
+	assert.Equal(t, []int32{600, 700}, valueField.GetScalars().GetIntData().GetData())
+}
@@ -1952,6 +1952,10 @@ func LackOfFieldsDataBySchema(schema *schemapb.CollectionSchema, fieldsData []*s
 			continue
 		}

+		if fieldSchema.GetNullable() || fieldSchema.GetDefaultValue() != nil {
+			continue
+		}
+
 		if _, ok := dataNameMap[fieldSchema.GetName()]; !ok {
 			if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && skipPkFieldCheck) ||
 				IsBM25FunctionOutputField(fieldSchema, schema) ||
@@ -4466,3 +4466,103 @@ func Test_reconstructStructFieldDataCommon(t *testing.T) {
 		assert.True(t, foundStruct2, "Should find struct2")
 	})
 }
+
+func TestLackOfFieldsDataBySchema(t *testing.T) {
+	schema := &schemapb.CollectionSchema{
+		Fields: []*schemapb.FieldSchema{
+			{FieldID: 100, Name: "pk_field", IsPrimaryKey: true, DataType: schemapb.DataType_Int64, AutoID: true},
+			{FieldID: 101, Name: "required_field", DataType: schemapb.DataType_Float},
+			{FieldID: 102, Name: "nullable_field", DataType: schemapb.DataType_VarChar, Nullable: true},
+			{FieldID: 103, Name: "default_value_field", DataType: schemapb.DataType_JSON, DefaultValue: &schemapb.ValueField{Data: &schemapb.ValueField_StringData{StringData: "{}"}}},
+		},
+	}
+
+	tests := []struct {
+		name             string
+		fieldsData       []*schemapb.FieldData
+		skipPkFieldCheck bool
+		expectErr        bool
+	}{
+		{
+			name: "all required fields present",
+			fieldsData: []*schemapb.FieldData{
+				{FieldName: "pk_field"},
+				{FieldName: "required_field"},
+				{FieldName: "nullable_field"},
+				{FieldName: "default_value_field"},
+			},
+			skipPkFieldCheck: false,
+			expectErr:        false,
+		},
+		{
+			name: "missing required field",
+			fieldsData: []*schemapb.FieldData{
+				{FieldName: "pk_field"},
+			},
+			skipPkFieldCheck: false,
+			expectErr:        true,
+		},
+		{
+			name: "missing nullable field is ok",
+			fieldsData: []*schemapb.FieldData{
+				{FieldName: "pk_field"},
+				{FieldName: "required_field"},
+				{FieldName: "default_value_field"},
+			},
+			skipPkFieldCheck: false,
+			expectErr:        false,
+		},
+		{
+			name: "missing default value field is ok",
+			fieldsData: []*schemapb.FieldData{
+				{FieldName: "pk_field"},
+				{FieldName: "required_field"},
+				{FieldName: "nullable_field"},
+			},
+			skipPkFieldCheck: false,
+			expectErr:        false,
+		},
+		{
+			name: "missing nullable and default value field is ok",
+			fieldsData: []*schemapb.FieldData{
+				{FieldName: "pk_field"},
+				{FieldName: "required_field"},
+			},
+			skipPkFieldCheck: false,
+			expectErr:        false,
+		},
+		{
+			name:             "empty fields data",
+			fieldsData:       []*schemapb.FieldData{},
+			skipPkFieldCheck: false,
+			expectErr:        true,
+		},
+		{
+			name: "skip pk check",
+			fieldsData: []*schemapb.FieldData{
+				{FieldName: "required_field"},
+			},
+			skipPkFieldCheck: true,
+			expectErr:        false,
+		},
+		{
+			name: "missing pk without skip",
+			fieldsData: []*schemapb.FieldData{
+				{FieldName: "required_field"},
+			},
+			skipPkFieldCheck: false,
+			expectErr:        true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := LackOfFieldsDataBySchema(schema, tt.fieldsData, tt.skipPkFieldCheck, false)
+			if tt.expectErr {
+				assert.Error(t, err)
+			} else {
+				assert.NoError(t, err)
+			}
+		})
+	}
+}
@@ -1071,7 +1071,7 @@ func DeleteFieldData(dst []*schemapb.FieldData) {
 	}
 }

-func UpdateFieldData(base, update []*schemapb.FieldData, idx int64) error {
+func UpdateFieldData(base, update []*schemapb.FieldData, baseIdx, updateIdx int64) error {
 	// Create a map for quick lookup of update fields by field ID
 	updateFieldMap := make(map[string]*schemapb.FieldData)
 	for _, fieldData := range update {
@@ -1085,16 +1085,26 @@ func UpdateFieldData(base, update []*schemapb.FieldData, idx int64) error {
 			continue
 		}

+		updateFieldIdx := updateIdx
 		// Update ValidData if present
 		if len(updateFieldData.GetValidData()) != 0 {
 			if len(baseFieldData.GetValidData()) != 0 {
-				baseFieldData.ValidData[idx] = updateFieldData.ValidData[idx]
+				baseFieldData.ValidData[baseIdx] = updateFieldData.ValidData[updateFieldIdx]
 			}

 			// update field data to null,only modify valid data
-			if !updateFieldData.ValidData[idx] {
+			if !updateFieldData.ValidData[updateFieldIdx] {
 				continue
 			}
+
+			// for nullable field data, such as data=[1,1], valid_data=[true, false, true]
+			// should update the updateFieldIdx to the expected valid index
+			updateFieldIdx = 0
+			for _, validData := range updateFieldData.GetValidData()[:updateIdx] {
+				if validData {
+					updateFieldIdx += 1
+				}
+			}
 		}

 		// Update field data based on type
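The remapping added above handles nullable update columns whose data array is compacted, i.e. it stores only the valid rows (e.g. data=[1,1] with valid_data=[true, false, true]). The data position for row updateIdx is simply the count of valid entries before it. A tiny self-contained sketch of that mapping, mirroring the counting loop in the hunk above (stand-alone illustration, not the library code):

package main

import "fmt"

// compactedIndex returns the position of row idx inside a compacted data
// array: the number of valid rows preceding idx.
func compactedIndex(valid []bool, idx int) int {
	n := 0
	for _, v := range valid[:idx] {
		if v {
			n++
		}
	}
	return n
}

func main() {
	valid := []bool{true, false, true}    // data array holds 2 entries, not 3
	fmt.Println(compactedIndex(valid, 2)) // 1: row 2 maps to data[1]
}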
@@ -1107,53 +1117,53 @@ func UpdateFieldData(base, update []*schemapb.FieldData, idx int64) error {
 		switch baseScalar.Data.(type) {
 		case *schemapb.ScalarField_BoolData:
 			updateData := updateScalar.GetBoolData()
-			if updateData != nil && int(idx) < len(updateData.Data) {
-				baseScalar.GetBoolData().Data[idx] = updateData.Data[idx]
+			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
+				baseScalar.GetBoolData().Data[baseIdx] = updateData.Data[updateFieldIdx]
 			}
 		case *schemapb.ScalarField_IntData:
 			updateData := updateScalar.GetIntData()
-			if updateData != nil && int(idx) < len(updateData.Data) {
-				baseScalar.GetIntData().Data[idx] = updateData.Data[idx]
+			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
+				baseScalar.GetIntData().Data[baseIdx] = updateData.Data[updateFieldIdx]
 			}
 		case *schemapb.ScalarField_LongData:
 			updateData := updateScalar.GetLongData()
-			if updateData != nil && int(idx) < len(updateData.Data) {
-				baseScalar.GetLongData().Data[idx] = updateData.Data[idx]
+			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
+				baseScalar.GetLongData().Data[baseIdx] = updateData.Data[updateFieldIdx]
 			}
 		case *schemapb.ScalarField_FloatData:
 			updateData := updateScalar.GetFloatData()
-			if updateData != nil && int(idx) < len(updateData.Data) {
-				baseScalar.GetFloatData().Data[idx] = updateData.Data[idx]
+			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
+				baseScalar.GetFloatData().Data[baseIdx] = updateData.Data[updateFieldIdx]
 			}
 		case *schemapb.ScalarField_DoubleData:
 			updateData := updateScalar.GetDoubleData()
-			if updateData != nil && int(idx) < len(updateData.Data) {
-				baseScalar.GetDoubleData().Data[idx] = updateData.Data[idx]
+			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
+				baseScalar.GetDoubleData().Data[baseIdx] = updateData.Data[updateFieldIdx]
 			}
 		case *schemapb.ScalarField_StringData:
 			updateData := updateScalar.GetStringData()
-			if updateData != nil && int(idx) < len(updateData.Data) {
-				baseScalar.GetStringData().Data[idx] = updateData.Data[idx]
+			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
+				baseScalar.GetStringData().Data[baseIdx] = updateData.Data[updateFieldIdx]
 			}
 		case *schemapb.ScalarField_ArrayData:
 			updateData := updateScalar.GetArrayData()
-			if updateData != nil && int(idx) < len(updateData.Data) {
-				baseScalar.GetArrayData().Data[idx] = updateData.Data[idx]
+			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
+				baseScalar.GetArrayData().Data[baseIdx] = updateData.Data[updateFieldIdx]
 			}
 		case *schemapb.ScalarField_JsonData:
 			updateData := updateScalar.GetJsonData()
 			baseData := baseScalar.GetJsonData()
-			if updateData != nil && int(idx) < len(updateData.Data) {
+			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
 				if baseFieldData.GetIsDynamic() {
 					// dynamic field is a json with only 1 level nested struct,
 					// so we need to unmarshal and iterate updateData's key value, and update the baseData's key value
 					var baseMap map[string]interface{}
 					var updateMap map[string]interface{}
 					// unmarshal base and update
-					if err := json.Unmarshal(baseData.Data[idx], &baseMap); err != nil {
+					if err := json.Unmarshal(baseData.Data[baseIdx], &baseMap); err != nil {
 						return fmt.Errorf("failed to unmarshal base json: %v", err)
 					}
-					if err := json.Unmarshal(updateData.Data[idx], &updateMap); err != nil {
+					if err := json.Unmarshal(updateData.Data[updateFieldIdx], &updateMap); err != nil {
 						return fmt.Errorf("failed to unmarshal update json: %v", err)
 					}
 					// merge
@@ -1165,9 +1175,9 @@ func UpdateFieldData(base, update []*schemapb.FieldData, idx int64) error {
 					if err != nil {
 						return fmt.Errorf("failed to marshal merged json: %v", err)
 					}
-					baseScalar.GetJsonData().Data[idx] = newJSON
+					baseScalar.GetJsonData().Data[baseIdx] = newJSON
 				} else {
-					baseScalar.GetJsonData().Data[idx] = updateData.Data[idx]
+					baseScalar.GetJsonData().Data[baseIdx] = updateData.Data[updateFieldIdx]
 				}
 			}
 		default:
@@ -1186,48 +1196,56 @@ func UpdateFieldData(base, update []*schemapb.FieldData, idx int64) error {
 			updateData := updateVector.GetBinaryVector()
 			if updateData != nil {
 				baseData := baseVector.GetBinaryVector()
-				startIdx := idx * (dim / 8)
-				endIdx := (idx + 1) * (dim / 8)
-				if int(endIdx) <= len(updateData) && int(endIdx) <= len(baseData) {
-					copy(baseData[startIdx:endIdx], updateData[startIdx:endIdx])
+				baseStartIdx := updateFieldIdx * (dim / 8)
+				baseEndIdx := (updateFieldIdx + 1) * (dim / 8)
+				updateStartIdx := updateFieldIdx * (dim / 8)
+				updateEndIdx := (updateFieldIdx + 1) * (dim / 8)
+				if int(updateEndIdx) <= len(updateData) && int(baseEndIdx) <= len(baseData) {
+					copy(baseData[baseStartIdx:baseEndIdx], updateData[updateStartIdx:updateEndIdx])
 				}
 			}
 		case *schemapb.VectorField_FloatVector:
 			updateData := updateVector.GetFloatVector()
 			if updateData != nil {
 				baseData := baseVector.GetFloatVector()
-				startIdx := idx * dim
-				endIdx := (idx + 1) * dim
-				if int(endIdx) <= len(updateData.Data) && int(endIdx) <= len(baseData.Data) {
-					copy(baseData.Data[startIdx:endIdx], updateData.Data[startIdx:endIdx])
+				baseStartIdx := updateFieldIdx * dim
+				baseEndIdx := (updateFieldIdx + 1) * dim
+				updateStartIdx := updateFieldIdx * dim
+				updateEndIdx := (updateFieldIdx + 1) * dim
+				if int(updateEndIdx) <= len(updateData.Data) && int(baseEndIdx) <= len(baseData.Data) {
+					copy(baseData.Data[baseStartIdx:baseEndIdx], updateData.Data[updateStartIdx:updateEndIdx])
 				}
 			}
 		case *schemapb.VectorField_Float16Vector:
 			updateData := updateVector.GetFloat16Vector()
 			if updateData != nil {
 				baseData := baseVector.GetFloat16Vector()
-				startIdx := idx * (dim * 2)
-				endIdx := (idx + 1) * (dim * 2)
-				if int(endIdx) <= len(updateData) && int(endIdx) <= len(baseData) {
-					copy(baseData[startIdx:endIdx], updateData[startIdx:endIdx])
+				baseStartIdx := updateFieldIdx * (dim * 2)
+				baseEndIdx := (updateFieldIdx + 1) * (dim * 2)
+				updateStartIdx := updateFieldIdx * (dim * 2)
+				updateEndIdx := (updateFieldIdx + 1) * (dim * 2)
+				if int(updateEndIdx) <= len(updateData) && int(baseEndIdx) <= len(baseData) {
+					copy(baseData[baseStartIdx:baseEndIdx], updateData[updateStartIdx:updateEndIdx])
 				}
 			}
 		case *schemapb.VectorField_Bfloat16Vector:
 			updateData := updateVector.GetBfloat16Vector()
 			if updateData != nil {
 				baseData := baseVector.GetBfloat16Vector()
-				startIdx := idx * (dim * 2)
-				endIdx := (idx + 1) * (dim * 2)
-				if int(endIdx) <= len(updateData) && int(endIdx) <= len(baseData) {
-					copy(baseData[startIdx:endIdx], updateData[startIdx:endIdx])
+				baseStartIdx := updateFieldIdx * (dim * 2)
+				baseEndIdx := (updateFieldIdx + 1) * (dim * 2)
+				updateStartIdx := updateFieldIdx * (dim * 2)
+				updateEndIdx := (updateFieldIdx + 1) * (dim * 2)
+				if int(updateEndIdx) <= len(updateData) && int(baseEndIdx) <= len(baseData) {
+					copy(baseData[baseStartIdx:baseEndIdx], updateData[updateStartIdx:updateEndIdx])
 				}
 			}
 		case *schemapb.VectorField_SparseFloatVector:
 			updateData := updateVector.GetSparseFloatVector()
-			if updateData != nil && int(idx) < len(updateData.Contents) {
+			if updateData != nil && int(updateFieldIdx) < len(updateData.Contents) {
 				baseData := baseVector.GetSparseFloatVector()
-				if int(idx) < len(baseData.Contents) {
-					baseData.Contents[idx] = updateData.Contents[idx]
+				if int(updateFieldIdx) < len(baseData.Contents) {
+					baseData.Contents[updateFieldIdx] = updateData.Contents[updateFieldIdx]
 					// Update dimension if necessary
 					if updateData.Dim > baseData.Dim {
 						baseData.Dim = updateData.Dim
@@ -1238,10 +1256,12 @@ func UpdateFieldData(base, update []*schemapb.FieldData, idx int64) error {
 			updateData := updateVector.GetInt8Vector()
 			if updateData != nil {
 				baseData := baseVector.GetInt8Vector()
-				startIdx := idx * dim
-				endIdx := (idx + 1) * dim
-				if int(endIdx) <= len(updateData) && int(endIdx) <= len(baseData) {
-					copy(baseData[startIdx:endIdx], updateData[startIdx:endIdx])
+				baseStartIdx := updateFieldIdx * dim
+				baseEndIdx := (updateFieldIdx + 1) * dim
+				updateStartIdx := updateFieldIdx * dim
+				updateEndIdx := (updateFieldIdx + 1) * dim
+				if int(updateEndIdx) <= len(updateData) && int(baseEndIdx) <= len(baseData) {
+					copy(baseData[baseStartIdx:baseEndIdx], updateData[updateStartIdx:updateEndIdx])
 				}
 			}
 		default:
@@ -3078,7 +3078,7 @@ func TestUpdateFieldData(t *testing.T) {
 		}

 		// Update index 1
-		err := UpdateFieldData(baseData, updateData, 1)
+		err := UpdateFieldData(baseData, updateData, 1, 1)
 		require.NoError(t, err)

 		// Check results
@@ -3141,7 +3141,7 @@ func TestUpdateFieldData(t *testing.T) {
 		}

 		// Update index 1 (second vector)
-		err := UpdateFieldData(baseData, updateData, 1)
+		err := UpdateFieldData(baseData, updateData, 1, 1)
 		require.NoError(t, err)

 		// Check results
@@ -3180,7 +3180,7 @@ func TestUpdateFieldData(t *testing.T) {
 		updateData := []*schemapb.FieldData{}

 		// Update should succeed but change nothing
-		err := UpdateFieldData(baseData, updateData, 1)
+		err := UpdateFieldData(baseData, updateData, 1, 1)
 		require.NoError(t, err)

 		// Data should remain unchanged
@@ -3193,7 +3193,7 @@ func TestUpdateFieldData(t *testing.T) {
 			Type:      schemapb.DataType_Int64,
 			FieldName: Int64FieldName,
 			FieldId:   Int64FieldID,
-			ValidData: []bool{true, true, false, true},
+			ValidData: []bool{true, true, true, true},
 			Field: &schemapb.FieldData_Scalars{
 				Scalars: &schemapb.ScalarField{
 					Data: &schemapb.ScalarField_LongData{
@@ -3216,7 +3216,7 @@ func TestUpdateFieldData(t *testing.T) {
 				Scalars: &schemapb.ScalarField{
 					Data: &schemapb.ScalarField_LongData{
 						LongData: &schemapb.LongArray{
-							Data: []int64{10, 20, 30, 40},
+							Data: []int64{30},
 						},
 					},
 				},
@@ -3225,9 +3225,9 @@ func TestUpdateFieldData(t *testing.T) {
 		}

 		// Update index 1
-		err := UpdateFieldData(baseData, updateData, 1)
+		err := UpdateFieldData(baseData, updateData, 1, 1)
 		require.NoError(t, err)
-		err = UpdateFieldData(baseData, updateData, 2)
+		err = UpdateFieldData(baseData, updateData, 2, 2)
 		require.NoError(t, err)

 		// Check that ValidData was updated
@@ -3283,7 +3283,7 @@ func TestUpdateFieldData(t *testing.T) {
 		}

 		// Test updating first row
-		err := UpdateFieldData(baseData, updateData, 0)
+		err := UpdateFieldData(baseData, updateData, 0, 0)
 		require.NoError(t, err)

 		// Verify first row was correctly merged
@@ -3298,7 +3298,7 @@ func TestUpdateFieldData(t *testing.T) {
 		assert.Equal(t, "new_value", result["key5"]) // New value

 		// Test updating second row
-		err = UpdateFieldData(baseData, updateData, 1)
+		err = UpdateFieldData(baseData, updateData, 1, 1)
 		require.NoError(t, err)

 		// Verify second row was correctly merged
@@ -3356,7 +3356,7 @@ func TestUpdateFieldData(t *testing.T) {
 		}

 		// Test updating
-		err := UpdateFieldData(baseData, updateData, 0)
+		err := UpdateFieldData(baseData, updateData, 0, 0)
 		require.NoError(t, err)

 		// For non-dynamic fields, the update should completely replace the old value
@@ -3407,7 +3407,7 @@ func TestUpdateFieldData(t *testing.T) {
 		}

 		// Test updating with invalid base JSON
-		err := UpdateFieldData(baseData, updateData, 0)
+		err := UpdateFieldData(baseData, updateData, 0, 0)
 		require.Error(t, err)
 		assert.Contains(t, err.Error(), "failed to unmarshal base json")
@@ -3416,8 +3416,126 @@ func TestUpdateFieldData(t *testing.T) {
 		updateData[0].GetScalars().GetJsonData().Data[0] = []byte(`invalid json`)

 		// Test updating with invalid update JSON
-		err = UpdateFieldData(baseData, updateData, 0)
+		err = UpdateFieldData(baseData, updateData, 0, 0)
 		require.Error(t, err)
 		assert.Contains(t, err.Error(), "failed to unmarshal update json")
 	})
+
+	t.Run("nullable field with valid data index mapping", func(t *testing.T) {
+		// Test the new logic for nullable fields where updateIdx needs to be mapped to actual data index
+		// Scenario: data=[1,2,3], valid_data=[true, false, true]
+		// updateIdx=1 should map to data index 0 (first valid data), updateIdx=2 should map to data index 1 (second valid data)
+
+		baseData := []*schemapb.FieldData{
+			{
+				Type:      schemapb.DataType_Int64,
+				FieldName: "nullable_int_field",
+				FieldId:   1,
+				ValidData: []bool{true, true, true}, // All base data is valid
+				Field: &schemapb.FieldData_Scalars{
+					Scalars: &schemapb.ScalarField{
+						Data: &schemapb.ScalarField_LongData{
+							LongData: &schemapb.LongArray{
+								Data: []int64{100, 200, 300},
+							},
+						},
+					},
+				},
+			},
+		}
+
+		updateData := []*schemapb.FieldData{
+			{
+				Type:      schemapb.DataType_Int64,
+				FieldName: "nullable_int_field",
+				FieldId:   1,
+				ValidData: []bool{true, false, true}, // Only indices 0 and 2 are valid
+				Field: &schemapb.FieldData_Scalars{
+					Scalars: &schemapb.ScalarField{
+						Data: &schemapb.ScalarField_LongData{
+							LongData: &schemapb.LongArray{
+								Data: []int64{999, 888, 777}, // Only indices 0 and 2 will be used
+							},
+						},
+					},
+				},
+			},
+		}
+
+		// Test updating at index 1 (which maps to data index 0 due to valid_data[1] = false)
+		err := UpdateFieldData(baseData, updateData, 0, 1)
+		require.NoError(t, err)
+
+		// Since valid_data[1] = false, no data should be updated
+		assert.Equal(t, int64(100), baseData[0].GetScalars().GetLongData().Data[0])
+		assert.Equal(t, false, baseData[0].ValidData[0])
+
+		// Test updating at index 2 (which maps to data index 1 due to valid_data[2] = true)
+		err = UpdateFieldData(baseData, updateData, 1, 2)
+		require.NoError(t, err)
+
+		// Index 2 maps to data index 1 because valid_data[0] = true, valid_data[1] = false
+		// So updateFieldIdx = 0 + 1 = 1, which means we use data[1] = 888
+		assert.Equal(t, int64(888), baseData[0].GetScalars().GetLongData().Data[1])
+		assert.Equal(t, true, baseData[0].ValidData[1])
+	})
+
+	t.Run("nullable field with complex valid data pattern", func(t *testing.T) {
+		// Test more complex pattern: data=[1,2,3,4,5], valid_data=[false, true, false, true, false]
+		// This tests the index mapping logic more thoroughly
+
+		baseData := []*schemapb.FieldData{
+			{
+				Type:      schemapb.DataType_Float,
+				FieldName: "complex_nullable_field",
+				FieldId:   2,
+				ValidData: []bool{true, true, true, true, true},
+				Field: &schemapb.FieldData_Scalars{
+					Scalars: &schemapb.ScalarField{
+						Data: &schemapb.ScalarField_FloatData{
+							FloatData: &schemapb.FloatArray{
+								Data: []float32{1.1, 2.2, 3.3, 4.4, 5.5},
+							},
+						},
+					},
+				},
+			},
+		}
+
+		updateData := []*schemapb.FieldData{
+			{
+				Type:      schemapb.DataType_Float,
+				FieldName: "complex_nullable_field",
+				FieldId:   2,
+				ValidData: []bool{false, true, false, true, false}, // Only indices 1 and 3 are valid
+				Field: &schemapb.FieldData_Scalars{
+					Scalars: &schemapb.ScalarField{
+						Data: &schemapb.ScalarField_FloatData{
+							FloatData: &schemapb.FloatArray{
+								Data: []float32{999.9, 888.8},
+							},
+						},
+					},
+				},
+			},
+		}
+
+		// Test updating at index 1 (valid_data[1] = true, so data[1] = 999.9)
+		err := UpdateFieldData(baseData, updateData, 1, 1)
+		require.NoError(t, err)
+		assert.Equal(t, float32(999.9), baseData[0].GetScalars().GetFloatData().Data[1])
+		assert.Equal(t, true, baseData[0].ValidData[0])
+
+		// Test updating at index 3 (valid_data[3] = true, so data[3] = 888.8)
+		err = UpdateFieldData(baseData, updateData, 3, 3)
+		require.NoError(t, err)
+		assert.Equal(t, float32(888.8), baseData[0].GetScalars().GetFloatData().Data[3])
+		assert.Equal(t, true, baseData[0].ValidData[1])
+
+		// Test updating at index 0 (valid_data[0] = false, so no data update, only ValidData update)
+		err = UpdateFieldData(baseData, updateData, 2, 2)
+		require.NoError(t, err)
+		assert.Equal(t, float32(3.3), baseData[0].GetScalars().GetFloatData().Data[2]) // Should remain unchanged
+		assert.Equal(t, false, baseData[0].ValidData[2])
+	})
 }