enhance: Deduplicate primary keys in upsert request batch (#45249)
issue: #44320

This change adds deduplication logic to handle duplicate primary keys within a single upsert batch, keeping the last occurrence of each primary key.

Key changes:
- Add DeduplicateFieldData function to remove duplicate PKs from field data, supporting both Int64 and VarChar primary keys
- Refactor fillFieldPropertiesBySchema into two separate functions: validateFieldDataColumns for validation and fillFieldPropertiesOnly for property filling, improving code clarity and reusability
- Integrate deduplication logic in upsertTask.PreExecute to automatically deduplicate data before processing
- Add comprehensive unit tests for deduplication with various PK types (Int64, VarChar) and field types (scalar, vector)
- Add Python integration tests to verify end-to-end behavior

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
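Before the diff itself, a minimal, self-contained Go sketch of the keep-the-last-occurrence semantics described above. The names here (lastOccurrenceIndices and the example slices) are illustrative only and are not the proxy implementation, which appears further down as findLastOccurrenceIndices / DeduplicateFieldData:

package main

import "fmt"

// lastOccurrenceIndices returns, in input order, the index of the last
// occurrence of each distinct key. Rows not selected here are the ones a
// dedup pass would drop. Illustrative sketch, not Milvus source code.
func lastOccurrenceIndices[K comparable](keys []K) []int {
    last := make(map[K]int, len(keys))
    for i, k := range keys {
        last[k] = i
    }
    keep := make([]int, 0, len(last))
    for i, k := range keys {
        if last[k] == i {
            keep = append(keep, i)
        }
    }
    return keep
}

func main() {
    pks := []int64{1, 2, 3, 2, 1}              // duplicate primary keys in one batch
    vals := []float32{1.1, 2.2, 3.3, 2.4, 1.5} // a parallel payload column

    keep := lastOccurrenceIndices(pks)
    dedupPKs := make([]int64, 0, len(keep))
    dedupVals := make([]float32, 0, len(keep))
    for _, i := range keep {
        dedupPKs = append(dedupPKs, pks[i])
        dedupVals = append(dedupVals, vals[i])
    }

    fmt.Println(dedupPKs)  // [3 2 1]
    fmt.Println(dedupVals) // [3.3 2.4 1.5]
}

Keeping the last occurrence matches the intent that, within one batch, a later upsert of the same primary key overrides an earlier one.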
commit 7aed88113c (parent e9506f1d64)
@@ -235,11 +235,15 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
        return err
    }

    // set field ID to insert field data
    err = fillFieldPropertiesBySchema(it.insertMsg.GetFieldsData(), schema.CollectionSchema)
    // Validate and set field ID to insert field data
    err = validateFieldDataColumns(it.insertMsg.GetFieldsData(), schema)
    if err != nil {
        log.Info("set fieldID to fieldData failed",
            zap.Error(err))
        log.Info("validate field data columns failed", zap.Error(err))
        return err
    }
    err = fillFieldPropertiesOnly(it.insertMsg.GetFieldsData(), schema)
    if err != nil {
        log.Info("fill field properties failed", zap.Error(err))
        return err
    }

@@ -899,11 +899,15 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
        return err
    }

    // set field ID to insert field data
    err = fillFieldPropertiesBySchema(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema.CollectionSchema)
    // Validate and set field ID to insert field data
    err = validateFieldDataColumns(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema)
    if err != nil {
        log.Warn("insert set fieldID to fieldData failed when upsert",
            zap.Error(err))
        log.Warn("validate field data columns failed when upsert", zap.Error(err))
        return merr.WrapErrAsInputErrorWhen(err, merr.ErrParameterInvalid)
    }
    err = fillFieldPropertiesOnly(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema)
    if err != nil {
        log.Warn("fill field properties failed when upsert", zap.Error(err))
        return merr.WrapErrAsInputErrorWhen(err, merr.ErrParameterInvalid)
    }

@@ -1068,6 +1072,26 @@ func (it *upsertTask) PreExecute(ctx context.Context) error {
        }
    }

    // deduplicate upsert data to handle duplicate primary keys in the same batch
    primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema.CollectionSchema)
    if err != nil {
        log.Warn("fail to get primary field schema", zap.Error(err))
        return err
    }
    deduplicatedFieldsData, newNumRows, err := DeduplicateFieldData(primaryFieldSchema, it.req.GetFieldsData(), schema)
    if err != nil {
        log.Warn("fail to deduplicate upsert data", zap.Error(err))
    }

    // dedup won't decrease numOfRows to 0
    if newNumRows > 0 && newNumRows != it.req.NumRows {
        log.Info("upsert data deduplicated",
            zap.Uint32("original_num_rows", it.req.NumRows),
            zap.Uint32("deduplicated_num_rows", newNumRows))
        it.req.FieldsData = deduplicatedFieldsData
        it.req.NumRows = newNumRows
    }

    it.upsertMsg = &msgstream.UpsertMsg{
        InsertMsg: &msgstream.InsertMsg{
            InsertRequest: &msgpb.InsertRequest{

@@ -1051,6 +1051,7 @@ func TestUpdateTask_PreExecute_InvalidNumRows(t *testing.T) {
    }, nil).Build()

    task := createTestUpdateTask()
    task.req.FieldsData = []*schemapb.FieldData{}
    task.req.NumRows = 0 // Invalid num_rows

    err := task.PreExecute(context.Background())

@@ -1534,3 +1535,334 @@ func TestUpsertTask_PlanNamespace_AfterPreExecute(t *testing.T) {
        assert.Equal(t, *task.req.Namespace, *capturedPlan.Namespace)
    })
}

func TestUpsertTask_Deduplicate_Int64PK(t *testing.T) {
    // Test deduplication with Int64 primary key
    primaryFieldSchema := &schemapb.FieldSchema{
        Name:         "id",
        FieldID:      100,
        DataType:     schemapb.DataType_Int64,
        IsPrimaryKey: true,
    }

    collSchema := &schemapb.CollectionSchema{
        Fields: []*schemapb.FieldSchema{
            primaryFieldSchema,
            {
                Name:     "float_field",
                FieldID:  101,
                DataType: schemapb.DataType_Float,
            },
        },
    }
    schema := newSchemaInfo(collSchema)

    // Create field data with duplicate IDs: [1, 2, 3, 2, 1]
    // Expected to keep last occurrence of each: [3, 2, 1] (indices 2, 3, 4)
    fieldsData := []*schemapb.FieldData{
        {
            FieldName: "id",
            Type:      schemapb.DataType_Int64,
            Field: &schemapb.FieldData_Scalars{
                Scalars: &schemapb.ScalarField{
                    Data: &schemapb.ScalarField_LongData{
                        LongData: &schemapb.LongArray{
                            Data: []int64{1, 2, 3, 2, 1},
                        },
                    },
                },
            },
        },
        {
            FieldName: "float_field",
            Type:      schemapb.DataType_Float,
            Field: &schemapb.FieldData_Scalars{
                Scalars: &schemapb.ScalarField{
                    Data: &schemapb.ScalarField_FloatData{
                        FloatData: &schemapb.FloatArray{
                            Data: []float32{1.1, 2.2, 3.3, 2.4, 1.5},
                        },
                    },
                },
            },
        },
    }

    deduplicatedFields, newNumRows, err := DeduplicateFieldData(primaryFieldSchema, fieldsData, schema)
    assert.NoError(t, err)
    assert.Equal(t, uint32(3), newNumRows)
    assert.Equal(t, 2, len(deduplicatedFields))

    // Check deduplicated primary keys
    pkField := deduplicatedFields[0]
    pkData := pkField.GetScalars().GetLongData().GetData()
    assert.Equal(t, 3, len(pkData))
    assert.Equal(t, []int64{3, 2, 1}, pkData)

    // Check corresponding float values (should be 3.3, 2.4, 1.5)
    floatField := deduplicatedFields[1]
    floatData := floatField.GetScalars().GetFloatData().GetData()
    assert.Equal(t, 3, len(floatData))
    assert.Equal(t, []float32{3.3, 2.4, 1.5}, floatData)
}

func TestUpsertTask_Deduplicate_VarCharPK(t *testing.T) {
    // Test deduplication with VarChar primary key
    primaryFieldSchema := &schemapb.FieldSchema{
        Name:         "id",
        FieldID:      100,
        DataType:     schemapb.DataType_VarChar,
        IsPrimaryKey: true,
    }

    collSchema := &schemapb.CollectionSchema{
        Fields: []*schemapb.FieldSchema{
            primaryFieldSchema,
            {
                Name:     "int_field",
                FieldID:  101,
                DataType: schemapb.DataType_Int64,
            },
        },
    }
    schema := newSchemaInfo(collSchema)

    // Create field data with duplicate IDs: ["a", "b", "c", "b", "a"]
    // Expected to keep last occurrence of each: ["c", "b", "a"] (indices 2, 3, 4)
    fieldsData := []*schemapb.FieldData{
        {
            FieldName: "id",
            Type:      schemapb.DataType_VarChar,
            Field: &schemapb.FieldData_Scalars{
                Scalars: &schemapb.ScalarField{
                    Data: &schemapb.ScalarField_StringData{
                        StringData: &schemapb.StringArray{
                            Data: []string{"a", "b", "c", "b", "a"},
                        },
                    },
                },
            },
        },
        {
            FieldName: "int_field",
            Type:      schemapb.DataType_Int64,
            Field: &schemapb.FieldData_Scalars{
                Scalars: &schemapb.ScalarField{
                    Data: &schemapb.ScalarField_LongData{
                        LongData: &schemapb.LongArray{
                            Data: []int64{100, 200, 300, 201, 101},
                        },
                    },
                },
            },
        },
    }

    deduplicatedFields, newNumRows, err := DeduplicateFieldData(primaryFieldSchema, fieldsData, schema)
    assert.NoError(t, err)
    assert.Equal(t, uint32(3), newNumRows)
    assert.Equal(t, 2, len(deduplicatedFields))

    // Check deduplicated primary keys
    pkField := deduplicatedFields[0]
    pkData := pkField.GetScalars().GetStringData().GetData()
    assert.Equal(t, 3, len(pkData))
    assert.Equal(t, []string{"c", "b", "a"}, pkData)

    // Check corresponding int64 values (should be 300, 201, 101)
    int64Field := deduplicatedFields[1]
    int64Data := int64Field.GetScalars().GetLongData().GetData()
    assert.Equal(t, 3, len(int64Data))
    assert.Equal(t, []int64{300, 201, 101}, int64Data)
}

func TestUpsertTask_Deduplicate_NoDuplicates(t *testing.T) {
    // Test with no duplicates - should return original data
    primaryFieldSchema := &schemapb.FieldSchema{
        Name:         "id",
        FieldID:      100,
        DataType:     schemapb.DataType_Int64,
        IsPrimaryKey: true,
    }

    collSchema := &schemapb.CollectionSchema{
        Fields: []*schemapb.FieldSchema{
            primaryFieldSchema,
        },
    }
    schema := newSchemaInfo(collSchema)

    fieldsData := []*schemapb.FieldData{
        {
            FieldName: "id",
            Type:      schemapb.DataType_Int64,
            Field: &schemapb.FieldData_Scalars{
                Scalars: &schemapb.ScalarField{
                    Data: &schemapb.ScalarField_LongData{
                        LongData: &schemapb.LongArray{
                            Data: []int64{1, 2, 3, 4, 5},
                        },
                    },
                },
            },
        },
    }

    deduplicatedFields, newNumRows, err := DeduplicateFieldData(primaryFieldSchema, fieldsData, schema)
    assert.NoError(t, err)
    assert.Equal(t, uint32(5), newNumRows)
    assert.Equal(t, 1, len(deduplicatedFields))

    // Should be unchanged
    pkField := deduplicatedFields[0]
    pkData := pkField.GetScalars().GetLongData().GetData()
    assert.Equal(t, []int64{1, 2, 3, 4, 5}, pkData)
}

func TestUpsertTask_Deduplicate_WithVector(t *testing.T) {
    // Test deduplication with vector field
    primaryFieldSchema := &schemapb.FieldSchema{
        Name:         "id",
        FieldID:      100,
        DataType:     schemapb.DataType_Int64,
        IsPrimaryKey: true,
    }

    collSchema := &schemapb.CollectionSchema{
        Fields: []*schemapb.FieldSchema{
            primaryFieldSchema,
            {
                Name:     "vector",
                FieldID:  101,
                DataType: schemapb.DataType_FloatVector,
            },
        },
    }
    schema := newSchemaInfo(collSchema)

    dim := 4
    // Create field data with duplicate IDs: [1, 2, 1]
    // Expected to keep indices [1, 2] (last occurrence of 2, last occurrence of 1)
    fieldsData := []*schemapb.FieldData{
        {
            FieldName: "id",
            Type:      schemapb.DataType_Int64,
            Field: &schemapb.FieldData_Scalars{
                Scalars: &schemapb.ScalarField{
                    Data: &schemapb.ScalarField_LongData{
                        LongData: &schemapb.LongArray{
                            Data: []int64{1, 2, 1},
                        },
                    },
                },
            },
        },
        {
            FieldName: "vector",
            Type:      schemapb.DataType_FloatVector,
            Field: &schemapb.FieldData_Vectors{
                Vectors: &schemapb.VectorField{
                    Dim: int64(dim),
                    Data: &schemapb.VectorField_FloatVector{
                        FloatVector: &schemapb.FloatArray{
                            Data: []float32{
                                1.0, 1.1, 1.2, 1.3, // vector for ID 1 (first occurrence)
                                2.0, 2.1, 2.2, 2.3, // vector for ID 2
                                1.4, 1.5, 1.6, 1.7, // vector for ID 1 (second occurrence - keep this)
                            },
                        },
                    },
                },
            },
        },
    }

    deduplicatedFields, newNumRows, err := DeduplicateFieldData(primaryFieldSchema, fieldsData, schema)
    assert.NoError(t, err)
    assert.Equal(t, uint32(2), newNumRows)
    assert.Equal(t, 2, len(deduplicatedFields))

    // Check deduplicated primary keys
    pkField := deduplicatedFields[0]
    pkData := pkField.GetScalars().GetLongData().GetData()
    assert.Equal(t, 2, len(pkData))
    assert.Equal(t, []int64{2, 1}, pkData)

    // Check corresponding vector (should keep vectors for ID 2 and ID 1's last occurrence)
    vectorField := deduplicatedFields[1]
    vectorData := vectorField.GetVectors().GetFloatVector().GetData()
    assert.Equal(t, 8, len(vectorData)) // 2 vectors * 4 dimensions
    expectedVector := []float32{
        2.0, 2.1, 2.2, 2.3, // vector for ID 2
        1.4, 1.5, 1.6, 1.7, // vector for ID 1 (last occurrence)
    }
    assert.Equal(t, expectedVector, vectorData)
}

func TestUpsertTask_Deduplicate_EmptyData(t *testing.T) {
    // Test with empty data
    primaryFieldSchema := &schemapb.FieldSchema{
        Name:         "id",
        FieldID:      100,
        DataType:     schemapb.DataType_Int64,
        IsPrimaryKey: true,
    }

    collSchema := &schemapb.CollectionSchema{
        Fields: []*schemapb.FieldSchema{
            primaryFieldSchema,
        },
    }
    schema := newSchemaInfo(collSchema)

    fieldsData := []*schemapb.FieldData{}

    deduplicatedFields, newNumRows, err := DeduplicateFieldData(primaryFieldSchema, fieldsData, schema)
    assert.NoError(t, err)
    assert.Equal(t, uint32(0), newNumRows)
    assert.Equal(t, 0, len(deduplicatedFields))
}

func TestUpsertTask_Deduplicate_MissingPrimaryKey(t *testing.T) {
    // Test with missing primary key field
    primaryFieldSchema := &schemapb.FieldSchema{
        Name:         "id",
        FieldID:      100,
        DataType:     schemapb.DataType_Int64,
        IsPrimaryKey: true,
    }

    collSchema := &schemapb.CollectionSchema{
        Fields: []*schemapb.FieldSchema{
            primaryFieldSchema,
            {
                Name:     "other_field",
                FieldID:  101,
                DataType: schemapb.DataType_Float,
            },
        },
    }
    schema := newSchemaInfo(collSchema)

    fieldsData := []*schemapb.FieldData{
        {
            FieldName: "other_field",
            Type:      schemapb.DataType_Float,
            Field: &schemapb.FieldData_Scalars{
                Scalars: &schemapb.ScalarField{
                    Data: &schemapb.ScalarField_FloatData{
                        FloatData: &schemapb.FloatArray{
                            Data: []float32{1.1, 2.2},
                        },
                    },
                },
            },
        },
    }

    _, _, err := DeduplicateFieldData(primaryFieldSchema, fieldsData, schema)
    assert.Error(t, err)
    // validateFieldDataColumns will fail first due to column count mismatch
    // or the function will fail when trying to find primary key
    assert.True(t, err != nil)
}

@@ -1046,6 +1046,103 @@ func parsePrimaryFieldData2IDs(fieldData *schemapb.FieldData) (*schemapb.IDs, error) {
    return primaryData, nil
}

// findLastOccurrenceIndices finds indices of last occurrences for each unique ID
func findLastOccurrenceIndices[T comparable](ids []T) []int {
    lastOccurrence := make(map[T]int, len(ids))
    for idx, id := range ids {
        lastOccurrence[id] = idx
    }

    keepIndices := make([]int, 0, len(lastOccurrence))
    for idx, id := range ids {
        if lastOccurrence[id] == idx {
            keepIndices = append(keepIndices, idx)
        }
    }
    return keepIndices
}

// DeduplicateFieldData removes duplicate primary keys from field data,
// keeping the last occurrence of each ID
func DeduplicateFieldData(primaryFieldSchema *schemapb.FieldSchema, fieldsData []*schemapb.FieldData, schema *schemaInfo) ([]*schemapb.FieldData, uint32, error) {
    if len(fieldsData) == 0 {
        return fieldsData, 0, nil
    }

    if err := fillFieldPropertiesOnly(fieldsData, schema); err != nil {
        return nil, 0, err
    }

    // find primary field data
    var primaryFieldData *schemapb.FieldData
    for _, field := range fieldsData {
        if field.GetFieldName() == primaryFieldSchema.GetName() {
            primaryFieldData = field
            break
        }
    }

    if primaryFieldData == nil {
        return nil, 0, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("must assign pk when upsert, primary field: %v", primaryFieldSchema.GetName()))
    }

    // get row count
    var numRows int
    switch primaryFieldData.Field.(type) {
    case *schemapb.FieldData_Scalars:
        scalarField := primaryFieldData.GetScalars()
        switch scalarField.Data.(type) {
        case *schemapb.ScalarField_LongData:
            numRows = len(scalarField.GetLongData().GetData())
        case *schemapb.ScalarField_StringData:
            numRows = len(scalarField.GetStringData().GetData())
        default:
            return nil, 0, merr.WrapErrParameterInvalidMsg("unsupported primary key type")
        }
    default:
        return nil, 0, merr.WrapErrParameterInvalidMsg("primary field must be scalar type")
    }

    if numRows == 0 {
        return fieldsData, 0, nil
    }

    // build map to track last occurrence of each primary key
    var keepIndices []int
    switch primaryFieldData.Field.(type) {
    case *schemapb.FieldData_Scalars:
        scalarField := primaryFieldData.GetScalars()
        switch scalarField.Data.(type) {
        case *schemapb.ScalarField_LongData:
            // for Int64 primary keys
            intIDs := scalarField.GetLongData().GetData()
            keepIndices = findLastOccurrenceIndices(intIDs)

        case *schemapb.ScalarField_StringData:
            // for VarChar primary keys
            strIDs := scalarField.GetStringData().GetData()
            keepIndices = findLastOccurrenceIndices(strIDs)
        }
    }

    // if no duplicates found, return original data
    if len(keepIndices) == numRows {
        return fieldsData, uint32(numRows), nil
    }

    log.Info("duplicate primary keys detected in upsert request, deduplicating",
        zap.Int("original_rows", numRows),
        zap.Int("deduplicated_rows", len(keepIndices)))

    // use typeutil.AppendFieldData to rebuild field data with deduplicated rows
    result := typeutil.PrepareResultFieldData(fieldsData, int64(len(keepIndices)))
    for _, idx := range keepIndices {
        typeutil.AppendFieldData(result, fieldsData, int64(idx))
    }

    return result, uint32(len(keepIndices)), nil
}

// autoGenPrimaryFieldData generate primary data when autoID == true
func autoGenPrimaryFieldData(fieldSchema *schemapb.FieldSchema, data interface{}) (*schemapb.FieldData, error) {
    var fieldData schemapb.FieldData

@@ -1105,52 +1202,34 @@ func autoGenDynamicFieldData(data [][]byte) *schemapb.FieldData {
    }
}

// fillFieldPropertiesBySchema set fieldID to fieldData according FieldSchemas
func fillFieldPropertiesBySchema(columns []*schemapb.FieldData, schema *schemapb.CollectionSchema) error {
    fieldName2Schema := make(map[string]*schemapb.FieldSchema)

// validateFieldDataColumns validates that all required fields are present and no unknown fields exist.
// It checks:
// 1. The number of columns matches the expected count (excluding BM25 output fields)
// 2. All field names exist in the schema
// Returns detailed error message listing expected and provided fields if validation fails.
func validateFieldDataColumns(columns []*schemapb.FieldData, schema *schemaInfo) error {
    expectColumnNum := 0

    // Count expected columns
    for _, field := range schema.GetFields() {
        fieldName2Schema[field.Name] = field
        if !typeutil.IsBM25FunctionOutputField(field, schema) {
        if !typeutil.IsBM25FunctionOutputField(field, schema.CollectionSchema) {
            expectColumnNum++
        }
    }

    for _, structField := range schema.GetStructArrayFields() {
        for _, field := range structField.GetFields() {
            fieldName2Schema[field.Name] = field
            expectColumnNum++
        }
        expectColumnNum += len(structField.GetFields())
    }

    // Validate column count
    if len(columns) != expectColumnNum {
        return fmt.Errorf("len(columns) mismatch the expectColumnNum, expectColumnNum: %d, len(columns): %d",
            expectColumnNum, len(columns))
    }

    // Validate field existence using schemaHelper
    for _, fieldData := range columns {
        if fieldSchema, ok := fieldName2Schema[fieldData.FieldName]; ok {
            fieldData.FieldId = fieldSchema.FieldID
            fieldData.Type = fieldSchema.DataType

            // Set the ElementType because it may not be set in the insert request.
            if fieldData.Type == schemapb.DataType_Array {
                fd, ok := fieldData.Field.(*schemapb.FieldData_Scalars)
                if !ok || fd.Scalars.GetArrayData() == nil {
                    return fmt.Errorf("field convert FieldData_Scalars fail in fieldData, fieldName: %s,"+
                        " collectionName:%s", fieldData.FieldName, schema.Name)
                }
                fd.Scalars.GetArrayData().ElementType = fieldSchema.ElementType
            } else if fieldData.Type == schemapb.DataType_ArrayOfVector {
                fd, ok := fieldData.Field.(*schemapb.FieldData_Vectors)
                if !ok || fd.Vectors.GetVectorArray() == nil {
                    return fmt.Errorf("field convert FieldData_Vectors fail in fieldData, fieldName: %s,"+
                        " collectionName:%s", fieldData.FieldName, schema.Name)
                }
                fd.Vectors.GetVectorArray().ElementType = fieldSchema.ElementType
            }
        } else {
            _, err := schema.schemaHelper.GetFieldFromNameDefaultJSON(fieldData.FieldName)
            if err != nil {
                return fmt.Errorf("fieldName %v not exist in collection schema", fieldData.FieldName)
            }
        }
@@ -1158,6 +1237,41 @@ func fillFieldPropertiesBySchema(columns []*schemapb.FieldData, schema *schemapb.CollectionSchema) error {
    return nil
}

// fillFieldPropertiesOnly fills field properties (FieldId, Type, ElementType) from schema.
// It assumes that columns have been validated and does not perform validation.
// Use validateFieldDataColumns before calling this function if validation is needed.
func fillFieldPropertiesOnly(columns []*schemapb.FieldData, schema *schemaInfo) error {
    for _, fieldData := range columns {
        // Use schemaHelper to get field schema, automatically handles dynamic fields
        fieldSchema, err := schema.schemaHelper.GetFieldFromNameDefaultJSON(fieldData.FieldName)
        if err != nil {
            return fmt.Errorf("fieldName %v not exist in collection schema", fieldData.FieldName)
        }

        fieldData.FieldId = fieldSchema.FieldID
        fieldData.Type = fieldSchema.DataType

        // Set the ElementType because it may not be set in the insert request.
        if fieldData.Type == schemapb.DataType_Array {
            fd, ok := fieldData.Field.(*schemapb.FieldData_Scalars)
            if !ok || fd.Scalars.GetArrayData() == nil {
                return fmt.Errorf("field convert FieldData_Scalars fail in fieldData, fieldName: %s, collectionName: %s",
                    fieldData.FieldName, schema.Name)
            }
            fd.Scalars.GetArrayData().ElementType = fieldSchema.ElementType
        } else if fieldData.Type == schemapb.DataType_ArrayOfVector {
            fd, ok := fieldData.Field.(*schemapb.FieldData_Vectors)
            if !ok || fd.Vectors.GetVectorArray() == nil {
                return fmt.Errorf("field convert FieldData_Vectors fail in fieldData, fieldName: %s, collectionName: %s",
                    fieldData.FieldName, schema.Name)
            }
            fd.Vectors.GetVectorArray().ElementType = fieldSchema.ElementType
        }
    }

    return nil
}

func ValidateUsername(username string) error {
    username = strings.TrimSpace(username)

@@ -606,28 +606,64 @@ func TestValidateMultipleVectorFields(t *testing.T) {
}

func TestFillFieldIDBySchema(t *testing.T) {
    schema := &schemapb.CollectionSchema{}
    columns := []*schemapb.FieldData{
        {
            FieldName: "TestFillFieldIDBySchema",
        },
    }

    // length mismatch
    assert.Error(t, fillFieldPropertiesBySchema(columns, schema))
    schema = &schemapb.CollectionSchema{
        Fields: []*schemapb.FieldSchema{
    t.Run("column count mismatch", func(t *testing.T) {
        collSchema := &schemapb.CollectionSchema{}
        schema := newSchemaInfo(collSchema)
        columns := []*schemapb.FieldData{
            {
                Name:      "TestFillFieldIDBySchema",
                DataType:  schemapb.DataType_Int64,
                FieldID:   1,
                FieldName: "TestFillFieldIDBySchema",
            },
        },
    }
    assert.NoError(t, fillFieldPropertiesBySchema(columns, schema))
    assert.Equal(t, "TestFillFieldIDBySchema", columns[0].FieldName)
    assert.Equal(t, schemapb.DataType_Int64, columns[0].Type)
    assert.Equal(t, int64(1), columns[0].FieldId)
}
        // Validation should fail due to column count mismatch
        assert.Error(t, validateFieldDataColumns(columns, schema))
    })

    t.Run("successful validation and fill", func(t *testing.T) {
        collSchema := &schemapb.CollectionSchema{
            Fields: []*schemapb.FieldSchema{
                {
                    Name:     "TestFillFieldIDBySchema",
                    DataType: schemapb.DataType_Int64,
                    FieldID:  1,
                },
            },
        }
        schema := newSchemaInfo(collSchema)
        columns := []*schemapb.FieldData{
            {
                FieldName: "TestFillFieldIDBySchema",
            },
        }
        // Validation should succeed
        assert.NoError(t, validateFieldDataColumns(columns, schema))
        // Fill properties should succeed
        assert.NoError(t, fillFieldPropertiesOnly(columns, schema))
        assert.Equal(t, "TestFillFieldIDBySchema", columns[0].FieldName)
        assert.Equal(t, schemapb.DataType_Int64, columns[0].Type)
        assert.Equal(t, int64(1), columns[0].FieldId)
    })

    t.Run("field not in schema", func(t *testing.T) {
        collSchema := &schemapb.CollectionSchema{
            Fields: []*schemapb.FieldSchema{
                {
                    Name:     "FieldA",
                    DataType: schemapb.DataType_Int64,
                    FieldID:  1,
                },
            },
        }
        schema := newSchemaInfo(collSchema)
        columns := []*schemapb.FieldData{
            {
                FieldName: "FieldB",
            },
        }
        // Validation should fail because FieldB is not in schema
        err := validateFieldDataColumns(columns, schema)
        assert.Error(t, err)
        assert.Contains(t, err.Error(), "not exist in collection schema")
    })
}

func TestValidateUsername(t *testing.T) {

@@ -550,4 +550,343 @@ class TestMilvusClientUpsertValid(TestMilvusClientV2Base):
        self.release_partitions(client, collection_name, partition_name)
        self.drop_partition(client, collection_name, partition_name)
        if self.has_collection(client, collection_name)[0]:
            self.drop_collection(client, collection_name)
        self.drop_collection(client, collection_name)


class TestMilvusClientUpsertDedup(TestMilvusClientV2Base):
    """Test case for upsert deduplication functionality"""

    @pytest.fixture(scope="function", params=["COSINE", "L2"])
    def metric_type(self, request):
        yield request.param

    @pytest.mark.tags(CaseLabel.L1)
    def test_milvus_client_upsert_dedup_int64_pk(self):
        """
        target: test upsert with duplicate int64 primary keys in same batch
        method:
            1. create collection with int64 primary key
            2. upsert data with duplicate primary keys [1, 2, 3, 2, 1]
            3. query to verify only last occurrence is kept
        expected: only 3 unique records exist, with data from last occurrence
        """
        client = self._client()
        collection_name = cf.gen_collection_name_by_testcase_name()

        # 1. create collection
        self.create_collection(client, collection_name, default_dim, consistency_level="Strong")

        # 2. upsert data with duplicate PKs: [1, 2, 3, 2, 1]
        # Expected: keep last occurrence -> [3, 2, 1] at indices [2, 3, 4]
        rng = np.random.default_rng(seed=19530)
        rows = [
            {default_primary_key_field_name: 1, default_vector_field_name: list(rng.random((1, default_dim))[0]),
             default_float_field_name: 1.0, default_string_field_name: "str_1_first"},
            {default_primary_key_field_name: 2, default_vector_field_name: list(rng.random((1, default_dim))[0]),
             default_float_field_name: 2.0, default_string_field_name: "str_2_first"},
            {default_primary_key_field_name: 3, default_vector_field_name: list(rng.random((1, default_dim))[0]),
             default_float_field_name: 3.0, default_string_field_name: "str_3"},
            {default_primary_key_field_name: 2, default_vector_field_name: list(rng.random((1, default_dim))[0]),
             default_float_field_name: 2.1, default_string_field_name: "str_2_last"},
            {default_primary_key_field_name: 1, default_vector_field_name: list(rng.random((1, default_dim))[0]),
             default_float_field_name: 1.1, default_string_field_name: "str_1_last"},
        ]

        results = self.upsert(client, collection_name, rows)[0]
        # After deduplication, should only have 3 records
        assert results['upsert_count'] == 3

        # 3. query to verify deduplication - should have only 3 unique records
        query_results = self.query(client, collection_name, filter="id >= 0")[0]
        assert len(query_results) == 3

        # Verify that last occurrence data is kept
        id_to_data = {item['id']: item for item in query_results}
        assert 1 in id_to_data
        assert 2 in id_to_data
        assert 3 in id_to_data

        # Check that data from last occurrence is preserved
        assert id_to_data[1]['float'] == 1.1
        assert id_to_data[1]['varchar'] == "str_1_last"
        assert id_to_data[2]['float'] == 2.1
        assert id_to_data[2]['varchar'] == "str_2_last"
        assert id_to_data[3]['float'] == 3.0
        assert id_to_data[3]['varchar'] == "str_3"

        self.drop_collection(client, collection_name)

    @pytest.mark.tags(CaseLabel.L1)
    def test_milvus_client_upsert_dedup_varchar_pk(self):
        """
        target: test upsert with duplicate varchar primary keys in same batch
        method:
            1. create collection with varchar primary key
            2. upsert data with duplicate primary keys ["a", "b", "c", "b", "a"]
            3. query to verify only last occurrence is kept
        expected: only 3 unique records exist, with data from last occurrence
        """
        client = self._client()
        collection_name = cf.gen_collection_name_by_testcase_name()

        # 1. create collection with varchar primary key
        schema = self.create_schema(client, enable_dynamic_field=True)[0]
        schema.add_field("id", DataType.VARCHAR, max_length=64, is_primary=True, auto_id=False)
        schema.add_field(default_vector_field_name, DataType.FLOAT_VECTOR, dim=default_dim)
        schema.add_field("age", DataType.INT64)
        index_params = self.prepare_index_params(client)[0]
        index_params.add_index(default_vector_field_name, metric_type="COSINE")
        self.create_collection(client, collection_name, default_dim, schema=schema,
                               index_params=index_params, consistency_level="Strong")

        # 2. upsert data with duplicate PKs: ["a", "b", "c", "b", "a"]
        # Expected: keep last occurrence -> ["c", "b", "a"] at indices [2, 3, 4]
        rng = np.random.default_rng(seed=19530)
        rows = [
            {"id": "a", default_vector_field_name: list(rng.random((1, default_dim))[0]),
             "age": 10},
            {"id": "b", default_vector_field_name: list(rng.random((1, default_dim))[0]),
             "age": 20},
            {"id": "c", default_vector_field_name: list(rng.random((1, default_dim))[0]),
             "age": 30},
            {"id": "b", default_vector_field_name: list(rng.random((1, default_dim))[0]),
             "age": 21},
            {"id": "a", default_vector_field_name: list(rng.random((1, default_dim))[0]),
             "age": 11},
        ]

        results = self.upsert(client, collection_name, rows)[0]
        # After deduplication, should only have 3 records
        assert results['upsert_count'] == 3

        # 3. query to verify deduplication
        query_results = self.query(client, collection_name, filter='id in ["a", "b", "c"]')[0]
        assert len(query_results) == 3

        # Verify that last occurrence data is kept
        id_to_data = {item['id']: item for item in query_results}
        assert "a" in id_to_data
        assert "b" in id_to_data
        assert "c" in id_to_data

        # Check that data from last occurrence is preserved
        assert id_to_data["a"]["age"] == 11
        assert id_to_data["b"]["age"] == 21
        assert id_to_data["c"]["age"] == 30

        self.drop_collection(client, collection_name)

    @pytest.mark.tags(CaseLabel.L1)
    def test_milvus_client_upsert_dedup_all_duplicates(self):
        """
        target: test upsert when all records have same primary key
        method:
            1. create collection
            2. upsert 5 records with same primary key
            3. query to verify only 1 record exists
        expected: only 1 record exists with data from last occurrence
        """
        client = self._client()
        collection_name = cf.gen_collection_name_by_testcase_name()

        # 1. create collection
        self.create_collection(client, collection_name, default_dim, consistency_level="Strong")

        # 2. upsert data where all have same PK (id=1)
        rng = np.random.default_rng(seed=19530)
        rows = [
            {default_primary_key_field_name: 1, default_vector_field_name: list(rng.random((1, default_dim))[0]),
             default_float_field_name: i * 1.0, default_string_field_name: f"version_{i}"}
            for i in range(5)
        ]

        results = self.upsert(client, collection_name, rows)[0]
        # After deduplication, should only have 1 record
        assert results['upsert_count'] == 1

        # 3. query to verify only 1 record exists
        query_results = self.query(client, collection_name, filter="id == 1")[0]
        assert len(query_results) == 1

        # Verify it's the last occurrence (i=4)
        assert query_results[0]['float'] == 4.0
        assert query_results[0]['varchar'] == "version_4"

        self.drop_collection(client, collection_name)

    @pytest.mark.tags(CaseLabel.L1)
    def test_milvus_client_upsert_dedup_no_duplicates(self):
        """
        target: test upsert with no duplicate primary keys
        method:
            1. create collection
            2. upsert data with unique primary keys
            3. query to verify all records exist
        expected: all records exist as-is
        """
        client = self._client()
        collection_name = cf.gen_collection_name_by_testcase_name()

        # 1. create collection
        self.create_collection(client, collection_name, default_dim, consistency_level="Strong")

        # 2. upsert data with unique PKs
        rng = np.random.default_rng(seed=19530)
        nb = 10
        rows = [
            {default_primary_key_field_name: i, default_vector_field_name: list(rng.random((1, default_dim))[0]),
             default_float_field_name: i * 1.0, default_string_field_name: str(i)}
            for i in range(nb)
        ]

        results = self.upsert(client, collection_name, rows)[0]
        # No deduplication should occur
        assert results['upsert_count'] == nb

        # 3. query to verify all records exist
        query_results = self.query(client, collection_name, filter=f"id >= 0")[0]
        assert len(query_results) == nb

        self.drop_collection(client, collection_name)

    @pytest.mark.tags(CaseLabel.L2)
    def test_milvus_client_upsert_dedup_large_batch(self):
        """
        target: test upsert deduplication with large batch
        method:
            1. create collection
            2. upsert large batch with 50% duplicate primary keys
            3. query to verify correct number of records
        expected: only unique records exist
        """
        client = self._client()
        collection_name = cf.gen_collection_name_by_testcase_name()

        # 1. create collection
        self.create_collection(client, collection_name, default_dim, consistency_level="Strong")

        # 2. upsert large batch where each ID appears twice
        rng = np.random.default_rng(seed=19530)
        nb = 500
        unique_ids = nb // 2  # 250 unique IDs

        rows = []
        for i in range(nb):
            pk = i % unique_ids  # This creates duplicates: 0,1,2...249,0,1,2...249
            rows.append({
                default_primary_key_field_name: pk,
                default_vector_field_name: list(rng.random((1, default_dim))[0]),
                default_float_field_name: float(i),  # Different value for each row
                default_string_field_name: f"batch_{i}"
            })

        results = self.upsert(client, collection_name, rows)[0]
        # After deduplication, should only have unique_ids records
        assert results['upsert_count'] == unique_ids

        # 3. query to verify correct number of records
        query_results = self.query(client, collection_name, filter=f"id >= 0", limit=1000)[0]
        assert len(query_results) == unique_ids

        # Verify that last occurrence is kept (should have higher float values)
        for item in query_results:
            pk = item['id']
            # Last occurrence of pk is at index (pk + unique_ids)
            expected_float = float(pk + unique_ids)
            assert item['float'] == expected_float
            assert item['varchar'] == f"batch_{pk + unique_ids}"

        self.drop_collection(client, collection_name)

    @pytest.mark.tags(CaseLabel.L1)
    def test_milvus_client_upsert_dedup_with_partition(self):
        """
        target: test upsert deduplication works correctly with partitions
        method:
            1. create collection with partition
            2. upsert data with duplicates to specific partition
            3. query to verify deduplication in partition
        expected: deduplication works within partition
        """
        client = self._client()
        collection_name = cf.gen_collection_name_by_testcase_name()
        partition_name = cf.gen_unique_str("partition")

        # 1. create collection and partition
        self.create_collection(client, collection_name, default_dim, consistency_level="Strong")
        self.create_partition(client, collection_name, partition_name)

        # 2. upsert data with duplicates to partition
        rng = np.random.default_rng(seed=19530)
        rows = [
            {default_primary_key_field_name: 1, default_vector_field_name: list(rng.random((1, default_dim))[0]),
             default_float_field_name: 1.0, default_string_field_name: "first"},
            {default_primary_key_field_name: 2, default_vector_field_name: list(rng.random((1, default_dim))[0]),
             default_float_field_name: 2.0, default_string_field_name: "unique"},
            {default_primary_key_field_name: 1, default_vector_field_name: list(rng.random((1, default_dim))[0]),
             default_float_field_name: 1.1, default_string_field_name: "last"},
        ]

        results = self.upsert(client, collection_name, rows, partition_name=partition_name)[0]
        assert results['upsert_count'] == 2

        # 3. query partition to verify deduplication
        query_results = self.query(client, collection_name, filter="id >= 0",
                                   partition_names=[partition_name])[0]
        assert len(query_results) == 2

        # Verify correct data
        id_to_data = {item['id']: item for item in query_results}
        assert id_to_data[1]['float'] == 1.1
        assert id_to_data[1]['varchar'] == "last"
        assert id_to_data[2]['float'] == 2.0
        assert id_to_data[2]['varchar'] == "unique"

        self.drop_collection(client, collection_name)

    @pytest.mark.tags(CaseLabel.L1)
    def test_milvus_client_upsert_dedup_with_vectors(self):
        """
        target: test upsert deduplication preserves correct vector data
        method:
            1. create collection
            2. upsert data with duplicate PKs but different vectors
            3. search to verify correct vector is preserved
        expected: vector from last occurrence is preserved
        """
        client = self._client()
        collection_name = cf.gen_collection_name_by_testcase_name()

        # 1. create collection
        self.create_collection(client, collection_name, default_dim, consistency_level="Strong")

        # 2. upsert data with duplicate PK=1 but different vectors
        # Create distinctly different vectors for easy verification
        first_vector = [1.0] * default_dim  # All 1.0
        last_vector = [2.0] * default_dim  # All 2.0

        rows = [
            {default_primary_key_field_name: 1, default_vector_field_name: first_vector,
             default_float_field_name: 1.0, default_string_field_name: "first"},
            {default_primary_key_field_name: 2, default_vector_field_name: [0.5] * default_dim,
             default_float_field_name: 2.0, default_string_field_name: "unique"},
            {default_primary_key_field_name: 1, default_vector_field_name: last_vector,
             default_float_field_name: 1.1, default_string_field_name: "last"},
        ]

        results = self.upsert(client, collection_name, rows)[0]
        assert results['upsert_count'] == 2

        # 3. query to get vector data
        query_results = self.query(client, collection_name, filter="id == 1",
                                   output_fields=["id", "vector", "float", "varchar"])[0]
        assert len(query_results) == 1

        # Verify it's the last occurrence with last_vector
        result = query_results[0]
        assert result['float'] == 1.1
        assert result['varchar'] == "last"
        # Vector should be last_vector (all 2.0)
        assert all(abs(v - 2.0) < 0.001 for v in result['vector'])

        self.drop_collection(client, collection_name)
@@ -2077,7 +2077,7 @@ class TestUpsertInvalid(TestcaseBase):
        log.debug(f"dirty_i: {dirty_i}")
        for i in range(len(data)):
            if data[i][dirty_i].__class__ is int:
                tmp = data[i][0]
                tmp = data[i][dirty_i]
                data[i][dirty_i] = "iamstring"
        error = {ct.err_code: 999, ct.err_msg: "The Input data type is inconsistent with defined schema"}
        collection_w.upsert(data=data, check_task=CheckTasks.err_res, check_items=error)