mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 09:08:43 +08:00
enhance: Support partial field updates with upsert API (#42877)
issue: #29735 Implement partial field update functionality for upsert operations, supporting scalar, vector, and dynamic JSON fields without requiring all collection fields. Changes: - Add queryPreExecute to retrieve existing records before upsert - Implement UpdateFieldData function for merging data - Add IDsChecker utility for efficient primary key lookups - Fix JSON data creation in tests using proper map marshaling - Add test cases for partial updates of different field types Signed-off-by: Wei Liu <wei.liu@zilliz.com>
This commit is contained in:
parent
b602b4187d
commit
d3c95eaa77
@ -318,8 +318,6 @@ github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfr
|
||||
github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8=
|
||||
github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc=
|
||||
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.6.1-0.20250807040333-531631e7fce6 h1:qTBOTsZ3OwEXkrHRqPn562ddkDqeToIY6CstLIaVQYs=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.6.1-0.20250807040333-531631e7fce6/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.6.1-0.20250807065533-ebdc11f5df17 h1:zyrKuc0rwT5xWIFkZr/bFWXXYbYvSBMT3iFITnaR8IE=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.6.1-0.20250807065533-ebdc11f5df17/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
|
||||
github.com/milvus-io/milvus/pkg/v2 v2.0.0-20250319085209-5a6b4e56d59e h1:VCr43pG4efacDbM4au70fh8/5hNTftoWzm1iEumvDWM=
|
||||
|
||||
@ -52,6 +52,7 @@ type columnBasedDataOption struct {
|
||||
collName string
|
||||
partitionName string
|
||||
columns []column.Column
|
||||
partialUpdate bool
|
||||
}
|
||||
|
||||
func (opt *columnBasedDataOption) WriteBackPKs(_ *entity.Schema, _ column.Column) error {
|
||||
@ -253,6 +254,11 @@ func (opt *columnBasedDataOption) WithPartition(partitionName string) *columnBas
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *columnBasedDataOption) WithPartialUpdate(partialUpdate bool) *columnBasedDataOption {
|
||||
opt.partialUpdate = partialUpdate
|
||||
return opt
|
||||
}
|
||||
|
||||
func (opt *columnBasedDataOption) CollectionName() string {
|
||||
return opt.collName
|
||||
}
|
||||
@ -282,6 +288,7 @@ func (opt *columnBasedDataOption) UpsertRequest(coll *entity.Collection) (*milvu
|
||||
FieldsData: fieldsData,
|
||||
NumRows: uint32(rowNum),
|
||||
SchemaTimestamp: coll.UpdateTimestamp,
|
||||
PartialUpdate: opt.partialUpdate,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -340,6 +347,7 @@ func (opt *rowBasedDataOption) UpsertRequest(coll *entity.Collection) (*milvuspb
|
||||
PartitionName: opt.partitionName,
|
||||
FieldsData: fieldsData,
|
||||
NumRows: uint32(rowNum),
|
||||
PartialUpdate: opt.partialUpdate,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
11
go.mod
11
go.mod
@ -86,6 +86,13 @@ require (
|
||||
mosn.io/holmes v1.0.2
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/gopherjs/gopherjs v1.12.80 // indirect
|
||||
github.com/jtolds/gls v4.20.0+incompatible // indirect
|
||||
github.com/smartystreets/assertions v1.2.0 // indirect
|
||||
github.com/smartystreets/goconvey v1.7.2 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
cloud.google.com/go v0.115.0 // indirect
|
||||
cloud.google.com/go/auth v0.6.1 // indirect
|
||||
@ -164,7 +171,6 @@ require (
|
||||
github.com/google/s2a-go v0.1.7 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.12.5 // indirect
|
||||
github.com/gopherjs/gopherjs v1.12.80 // indirect
|
||||
github.com/gorilla/websocket v1.4.2 // indirect
|
||||
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect
|
||||
@ -177,7 +183,6 @@ require (
|
||||
github.com/ianlancetaylor/cgosymbolizer v0.0.0-20221217025313-27d3c9f66b6a // indirect
|
||||
github.com/jonboulle/clockwork v0.2.2 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/jtolds/gls v4.20.0+incompatible // indirect
|
||||
github.com/klauspost/asmfmt v1.3.2 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.8 // indirect
|
||||
github.com/kr/pretty v0.3.1 // indirect
|
||||
@ -220,8 +225,6 @@ require (
|
||||
github.com/shirou/gopsutil/v3 v3.23.7 // indirect
|
||||
github.com/shoenig/go-m1cpu v0.1.6 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
github.com/smartystreets/assertions v1.2.0 // indirect
|
||||
github.com/smartystreets/goconvey v1.7.2 // indirect
|
||||
github.com/spaolacci/murmur3 v1.1.0 // indirect
|
||||
github.com/spf13/afero v1.6.0 // indirect
|
||||
github.com/spf13/jwalterweatherman v1.1.0 // indirect
|
||||
|
||||
@ -755,7 +755,7 @@ func (h *HandlersV1) insert(c *gin.Context) {
|
||||
return nil, RestRequestInterceptorErr
|
||||
}
|
||||
body, _ := c.Get(gin.BodyBytesKey)
|
||||
err, httpReq.Data, _ = checkAndSetData(body.([]byte), collSchema)
|
||||
err, httpReq.Data, _ = checkAndSetData(body.([]byte), collSchema, false)
|
||||
if err != nil {
|
||||
log.Warn("high level restful api, fail to deal with insert data", zap.Any("body", body), zap.Error(err))
|
||||
HTTPAbortReturn(c, http.StatusOK, gin.H{
|
||||
@ -765,7 +765,7 @@ func (h *HandlersV1) insert(c *gin.Context) {
|
||||
return nil, RestRequestInterceptorErr
|
||||
}
|
||||
insertReq := req.(*milvuspb.InsertRequest)
|
||||
insertReq.FieldsData, err = anyToColumns(httpReq.Data, nil, collSchema, true)
|
||||
insertReq.FieldsData, err = anyToColumns(httpReq.Data, nil, collSchema, true, false)
|
||||
if err != nil {
|
||||
log.Warn("high level restful api, fail to deal with insert data", zap.Any("data", httpReq.Data), zap.Error(err))
|
||||
HTTPAbortReturn(c, http.StatusOK, gin.H{
|
||||
@ -831,6 +831,7 @@ func (h *HandlersV1) upsert(c *gin.Context) {
|
||||
httpReq.DbName = singleUpsertReq.DbName
|
||||
httpReq.CollectionName = singleUpsertReq.CollectionName
|
||||
httpReq.Data = []map[string]interface{}{singleUpsertReq.Data}
|
||||
httpReq.PartialUpdate = singleUpsertReq.PartialUpdate
|
||||
}
|
||||
if httpReq.CollectionName == "" || httpReq.Data == nil {
|
||||
log.Warn("high level restful api, upsert require parameter: [collectionName, data], but miss")
|
||||
@ -844,6 +845,7 @@ func (h *HandlersV1) upsert(c *gin.Context) {
|
||||
DbName: httpReq.DbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
NumRows: uint32(len(httpReq.Data)),
|
||||
PartialUpdate: httpReq.PartialUpdate,
|
||||
}
|
||||
c.Set(ContextRequest, req)
|
||||
username, _ := c.Get(ContextUsername)
|
||||
@ -861,7 +863,7 @@ func (h *HandlersV1) upsert(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
body, _ := c.Get(gin.BodyBytesKey)
|
||||
err, httpReq.Data, _ = checkAndSetData(body.([]byte), collSchema)
|
||||
err, httpReq.Data, _ = checkAndSetData(body.([]byte), collSchema, httpReq.PartialUpdate)
|
||||
if err != nil {
|
||||
log.Warn("high level restful api, fail to deal with upsert data", zap.Any("body", body), zap.Error(err))
|
||||
HTTPAbortReturn(c, http.StatusOK, gin.H{
|
||||
@ -871,7 +873,7 @@ func (h *HandlersV1) upsert(c *gin.Context) {
|
||||
return nil, RestRequestInterceptorErr
|
||||
}
|
||||
upsertReq := req.(*milvuspb.UpsertRequest)
|
||||
upsertReq.FieldsData, err = anyToColumns(httpReq.Data, nil, collSchema, false)
|
||||
upsertReq.FieldsData, err = anyToColumns(httpReq.Data, nil, collSchema, false, httpReq.PartialUpdate)
|
||||
if err != nil {
|
||||
log.Warn("high level restful api, fail to deal with upsert data", zap.Any("data", httpReq.Data), zap.Error(err))
|
||||
HTTPAbortReturn(c, http.StatusOK, gin.H{
|
||||
|
||||
@ -981,7 +981,7 @@ func (h *HandlersV2) insert(ctx context.Context, c *gin.Context, anyReq any, dbN
|
||||
}
|
||||
body, _ := c.Get(gin.BodyBytesKey)
|
||||
var validDataMap map[string][]bool
|
||||
err, httpReq.Data, validDataMap = checkAndSetData(body.([]byte), collSchema)
|
||||
err, httpReq.Data, validDataMap = checkAndSetData(body.([]byte), collSchema, false)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("high level restful api, fail to deal with insert data", zap.Error(err), zap.String("body", string(body.([]byte))))
|
||||
HTTPAbortReturn(c, http.StatusOK, gin.H{
|
||||
@ -992,7 +992,7 @@ func (h *HandlersV2) insert(ctx context.Context, c *gin.Context, anyReq any, dbN
|
||||
}
|
||||
|
||||
req.NumRows = uint32(len(httpReq.Data))
|
||||
req.FieldsData, err = anyToColumns(httpReq.Data, validDataMap, collSchema, true)
|
||||
req.FieldsData, err = anyToColumns(httpReq.Data, validDataMap, collSchema, true, false)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("high level restful api, fail to deal with insert data", zap.Any("data", httpReq.Data), zap.Error(err))
|
||||
HTTPAbortReturn(c, http.StatusOK, gin.H{
|
||||
@ -1045,6 +1045,7 @@ func (h *HandlersV2) upsert(ctx context.Context, c *gin.Context, anyReq any, dbN
|
||||
DbName: dbName,
|
||||
CollectionName: httpReq.CollectionName,
|
||||
PartitionName: httpReq.PartitionName,
|
||||
PartialUpdate: httpReq.PartialUpdate,
|
||||
// PartitionName: "_default",
|
||||
}
|
||||
c.Set(ContextRequest, req)
|
||||
@ -1055,7 +1056,7 @@ func (h *HandlersV2) upsert(ctx context.Context, c *gin.Context, anyReq any, dbN
|
||||
}
|
||||
body, _ := c.Get(gin.BodyBytesKey)
|
||||
var validDataMap map[string][]bool
|
||||
err, httpReq.Data, validDataMap = checkAndSetData(body.([]byte), collSchema)
|
||||
err, httpReq.Data, validDataMap = checkAndSetData(body.([]byte), collSchema, httpReq.PartialUpdate)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("high level restful api, fail to deal with upsert data", zap.Any("body", body), zap.Error(err))
|
||||
HTTPAbortReturn(c, http.StatusOK, gin.H{
|
||||
@ -1066,7 +1067,7 @@ func (h *HandlersV2) upsert(ctx context.Context, c *gin.Context, anyReq any, dbN
|
||||
}
|
||||
|
||||
req.NumRows = uint32(len(httpReq.Data))
|
||||
req.FieldsData, err = anyToColumns(httpReq.Data, validDataMap, collSchema, false)
|
||||
req.FieldsData, err = anyToColumns(httpReq.Data, validDataMap, collSchema, false, httpReq.PartialUpdate)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("high level restful api, fail to deal with upsert data", zap.Any("data", httpReq.Data), zap.Error(err))
|
||||
HTTPAbortReturn(c, http.StatusOK, gin.H{
|
||||
|
||||
@ -71,12 +71,14 @@ type UpsertReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" validate:"required"`
|
||||
Data []map[string]interface{} `json:"data" validate:"required"`
|
||||
PartialUpdate bool `json:"partialUpdate"`
|
||||
}
|
||||
|
||||
type SingleUpsertReq struct {
|
||||
DbName string `json:"dbName"`
|
||||
CollectionName string `json:"collectionName" validate:"required"`
|
||||
Data map[string]interface{} `json:"data" validate:"required"`
|
||||
PartialUpdate bool `json:"partialUpdate"`
|
||||
}
|
||||
|
||||
type SearchReq struct {
|
||||
|
||||
@ -272,6 +272,7 @@ type CollectionDataReq struct {
|
||||
CollectionName string `json:"collectionName" binding:"required"`
|
||||
PartitionName string `json:"partitionName"`
|
||||
Data []map[string]interface{} `json:"data" binding:"required"`
|
||||
PartialUpdate bool `json:"partialUpdate"`
|
||||
}
|
||||
|
||||
func (req *CollectionDataReq) GetDbName() string { return req.DbName }
|
||||
|
||||
@ -284,7 +284,7 @@ func printIndexes(indexes []*milvuspb.IndexDescription) []gin.H {
|
||||
|
||||
// --------------------- insert param --------------------- //
|
||||
|
||||
func checkAndSetData(body []byte, collSchema *schemapb.CollectionSchema) (error, []map[string]interface{}, map[string][]bool) {
|
||||
func checkAndSetData(body []byte, collSchema *schemapb.CollectionSchema, partialUpdate bool) (error, []map[string]interface{}, map[string][]bool) {
|
||||
var reallyDataArray []map[string]interface{}
|
||||
validDataMap := make(map[string][]bool)
|
||||
dataResult := gjson.GetBytes(body, HTTPRequestData)
|
||||
@ -321,7 +321,14 @@ func checkAndSetData(body []byte, collSchema *schemapb.CollectionSchema) (error,
|
||||
}
|
||||
}
|
||||
|
||||
dataString := gjson.Get(data.Raw, fieldName).String()
|
||||
// For partial update, check if field exists in the data
|
||||
fieldValue := gjson.Get(data.Raw, fieldName)
|
||||
if partialUpdate && !fieldValue.Exists() {
|
||||
// Skip fields that are not provided in partial update
|
||||
continue
|
||||
}
|
||||
|
||||
dataString := fieldValue.String()
|
||||
// if has pass pk than just to try to set it
|
||||
if field.IsPrimaryKey && field.AutoID && len(dataString) == 0 {
|
||||
continue
|
||||
@ -732,7 +739,7 @@ func convertToIntArray(dataType schemapb.DataType, arr interface{}) []int32 {
|
||||
return res
|
||||
}
|
||||
|
||||
func anyToColumns(rows []map[string]interface{}, validDataMap map[string][]bool, sch *schemapb.CollectionSchema, inInsert bool) ([]*schemapb.FieldData, error) {
|
||||
func anyToColumns(rows []map[string]interface{}, validDataMap map[string][]bool, sch *schemapb.CollectionSchema, inInsert bool, partialUpdate bool) ([]*schemapb.FieldData, error) {
|
||||
rowsLen := len(rows)
|
||||
if rowsLen == 0 {
|
||||
return []*schemapb.FieldData{}, errors.New("no row need to be convert to columns")
|
||||
@ -810,11 +817,12 @@ func anyToColumns(rows []map[string]interface{}, validDataMap map[string][]bool,
|
||||
IsDynamic: field.IsDynamic,
|
||||
}
|
||||
}
|
||||
if len(nameDims) == 0 && len(sch.Functions) == 0 {
|
||||
if len(nameDims) == 0 && len(sch.Functions) == 0 && !partialUpdate {
|
||||
return nil, fmt.Errorf("collection: %s has no vector field or functions", sch.Name)
|
||||
}
|
||||
|
||||
dynamicCol := make([][]byte, 0, rowsLen)
|
||||
fieldLen := make(map[string]int)
|
||||
|
||||
for _, row := range rows {
|
||||
// collection schema name need not be same, since receiver could have other names
|
||||
@ -844,8 +852,12 @@ func anyToColumns(rows []map[string]interface{}, validDataMap map[string][]bool,
|
||||
continue
|
||||
}
|
||||
if !ok {
|
||||
if partialUpdate {
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("row %d does not has field %s", idx, field.Name)
|
||||
}
|
||||
fieldLen[field.Name] += 1
|
||||
switch field.DataType {
|
||||
case schemapb.DataType_Bool:
|
||||
nameColumns[field.Name] = append(nameColumns[field.Name].([]bool), candi.v.Interface().(bool))
|
||||
@ -923,6 +935,22 @@ func anyToColumns(rows []map[string]interface{}, validDataMap map[string][]bool,
|
||||
}
|
||||
columns := make([]*schemapb.FieldData, 0, len(nameColumns))
|
||||
for name, column := range nameColumns {
|
||||
if fieldLen[name] == 0 && partialUpdate {
|
||||
// for partial update, skip update for nullable field
|
||||
// cause we cannot distinguish between missing fields and fields explicitly set to null
|
||||
log.Info("skip empty field for partial update",
|
||||
zap.String("fieldName", name))
|
||||
continue
|
||||
}
|
||||
if fieldLen[name] != rowsLen && partialUpdate {
|
||||
// for partial update, if try to update different field in different rows, return error
|
||||
log.Info("field len is not equal to rows len",
|
||||
zap.String("fieldName", name),
|
||||
zap.Int("fieldLen", fieldLen[name]),
|
||||
zap.Int("rowsLen", rowsLen))
|
||||
return nil, fmt.Errorf("column %s has length %d, expected %d", name, fieldLen[name], rowsLen)
|
||||
}
|
||||
|
||||
colData := fieldData[name]
|
||||
switch colData.Type {
|
||||
case schemapb.DataType_Bool:
|
||||
|
||||
@ -604,12 +604,12 @@ func TestAnyToColumns(t *testing.T) {
|
||||
req := InsertReq{}
|
||||
coll := generateCollectionSchema(schemapb.DataType_Int64, false, true)
|
||||
var err error
|
||||
err, req.Data, _ = checkAndSetData(body, coll)
|
||||
err, req.Data, _ = checkAndSetData(body, coll, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, int64(0), req.Data[0]["id"])
|
||||
assert.Equal(t, int64(1), req.Data[0]["book_id"])
|
||||
assert.Equal(t, int64(2), req.Data[0]["word_count"])
|
||||
fieldsData, err := anyToColumns(req.Data, nil, coll, true)
|
||||
fieldsData, err := anyToColumns(req.Data, nil, coll, true, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, true, fieldsData[len(fieldsData)-1].IsDynamic)
|
||||
assert.Equal(t, schemapb.DataType_JSON, fieldsData[len(fieldsData)-1].Type)
|
||||
@ -621,12 +621,12 @@ func TestAnyToColumns(t *testing.T) {
|
||||
req := InsertReq{}
|
||||
coll := generateCollectionSchema(schemapb.DataType_Int64, false, true)
|
||||
var err error
|
||||
err, req.Data, _ = checkAndSetData(body, coll)
|
||||
err, req.Data, _ = checkAndSetData(body, coll, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, int64(0), req.Data[0]["id"])
|
||||
assert.Equal(t, int64(1), req.Data[0]["book_id"])
|
||||
assert.Equal(t, int64(2), req.Data[0]["word_count"])
|
||||
fieldsData, err := anyToColumns(req.Data, nil, coll, false)
|
||||
fieldsData, err := anyToColumns(req.Data, nil, coll, false, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, true, fieldsData[len(fieldsData)-1].IsDynamic)
|
||||
assert.Equal(t, schemapb.DataType_JSON, fieldsData[len(fieldsData)-1].Type)
|
||||
@ -638,12 +638,12 @@ func TestAnyToColumns(t *testing.T) {
|
||||
req := InsertReq{}
|
||||
coll := generateCollectionSchema(schemapb.DataType_Int64, true, true)
|
||||
var err error
|
||||
err, req.Data, _ = checkAndSetData(body, coll)
|
||||
err, req.Data, _ = checkAndSetData(body, coll, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, int64(0), req.Data[0]["id"])
|
||||
assert.Equal(t, int64(1), req.Data[0]["book_id"])
|
||||
assert.Equal(t, int64(2), req.Data[0]["word_count"])
|
||||
_, err = anyToColumns(req.Data, nil, coll, true)
|
||||
_, err = anyToColumns(req.Data, nil, coll, true, false)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, true, strings.HasPrefix(err.Error(), "no need to pass pk field"))
|
||||
})
|
||||
@ -652,7 +652,7 @@ func TestAnyToColumns(t *testing.T) {
|
||||
body := []byte("{\"data\": {\"id\": 0, \"book_id\": 1, \"book_intro\": [0.1, 0.2], \"word_count\": 2, \"classified\": false, \"databaseID\": null}}")
|
||||
coll := generateCollectionSchema(schemapb.DataType_Int64, true, false)
|
||||
var err error
|
||||
err, _, _ = checkAndSetData(body, coll)
|
||||
err, _, _ = checkAndSetData(body, coll, false)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, true, strings.HasPrefix(err.Error(), "has pass more fiel"))
|
||||
})
|
||||
@ -662,12 +662,12 @@ func TestAnyToColumns(t *testing.T) {
|
||||
req := InsertReq{}
|
||||
coll := generateCollectionSchema(schemapb.DataType_Int64, false, false)
|
||||
var err error
|
||||
err, req.Data, _ = checkAndSetData(body, coll)
|
||||
err, req.Data, _ = checkAndSetData(body, coll, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, int64(1), req.Data[0]["book_id"])
|
||||
assert.Equal(t, []float32{0.1, 0.2}, req.Data[0]["book_intro"])
|
||||
assert.Equal(t, int64(2), req.Data[0]["word_count"])
|
||||
fieldsData, err := anyToColumns(req.Data, nil, coll, true)
|
||||
fieldsData, err := anyToColumns(req.Data, nil, coll, true, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, 3, len(fieldsData))
|
||||
assert.Equal(t, false, fieldsData[len(fieldsData)-1].IsDynamic)
|
||||
@ -677,7 +677,7 @@ func TestAnyToColumns(t *testing.T) {
|
||||
body := []byte("{\"data\": { \"book_intro\": [0.1, 0.2], \"word_count\": 2}}")
|
||||
coll := generateCollectionSchema(schemapb.DataType_Int64, false, false)
|
||||
var err error
|
||||
err, _, _ = checkAndSetData(body, coll)
|
||||
err, _, _ = checkAndSetData(body, coll, false)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, true, strings.HasPrefix(err.Error(), "strconv.ParseInt: parsing \"\": invalid syntax"))
|
||||
})
|
||||
@ -687,11 +687,11 @@ func TestAnyToColumns(t *testing.T) {
|
||||
req := InsertReq{}
|
||||
coll := generateCollectionSchema(schemapb.DataType_Int64, true, false)
|
||||
var err error
|
||||
err, req.Data, _ = checkAndSetData(body, coll)
|
||||
err, req.Data, _ = checkAndSetData(body, coll, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, []float32{0.1, 0.2}, req.Data[0]["book_intro"])
|
||||
assert.Equal(t, int64(2), req.Data[0]["word_count"])
|
||||
fieldsData, err := anyToColumns(req.Data, nil, coll, true)
|
||||
fieldsData, err := anyToColumns(req.Data, nil, coll, true, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, 2, len(fieldsData))
|
||||
assert.Equal(t, false, fieldsData[len(fieldsData)-1].IsDynamic)
|
||||
@ -702,12 +702,12 @@ func TestAnyToColumns(t *testing.T) {
|
||||
req := InsertReq{}
|
||||
coll := generateCollectionSchema(schemapb.DataType_Int64, true, false)
|
||||
var err error
|
||||
err, req.Data, _ = checkAndSetData(body, coll)
|
||||
err, req.Data, _ = checkAndSetData(body, coll, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, int64(1), req.Data[0]["book_id"])
|
||||
assert.Equal(t, []float32{0.1, 0.2}, req.Data[0]["book_intro"])
|
||||
assert.Equal(t, int64(2), req.Data[0]["word_count"])
|
||||
fieldsData, err := anyToColumns(req.Data, nil, coll, false)
|
||||
fieldsData, err := anyToColumns(req.Data, nil, coll, false, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, 3, len(fieldsData))
|
||||
assert.Equal(t, false, fieldsData[len(fieldsData)-1].IsDynamic)
|
||||
@ -718,16 +718,117 @@ func TestAnyToColumns(t *testing.T) {
|
||||
req := InsertReq{}
|
||||
coll := generateCollectionSchema(schemapb.DataType_Int64, true, false)
|
||||
var err error
|
||||
err, req.Data, _ = checkAndSetData(body, coll)
|
||||
err, req.Data, _ = checkAndSetData(body, coll, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, int64(1), req.Data[0]["book_id"])
|
||||
assert.Equal(t, []float32{0.1, 0.2}, req.Data[0]["book_intro"])
|
||||
assert.Equal(t, int64(2), req.Data[0]["word_count"])
|
||||
fieldsData, err := anyToColumns(req.Data, nil, coll, false)
|
||||
fieldsData, err := anyToColumns(req.Data, nil, coll, false, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, 3, len(fieldsData))
|
||||
assert.Equal(t, false, fieldsData[len(fieldsData)-1].IsDynamic)
|
||||
})
|
||||
|
||||
t.Run("partial update with inconsistent fields should fail", func(t *testing.T) {
|
||||
// Create a simple schema with two fields: a and b
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "test_collection",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: "id",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
IsPrimaryKey: true,
|
||||
AutoID: false,
|
||||
},
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "a",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
FieldID: 102,
|
||||
Name: "b",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
EnableDynamicField: false,
|
||||
}
|
||||
|
||||
// Create two rows: first row updates only field 'a', second row updates only field 'b'
|
||||
rows := []map[string]interface{}{
|
||||
{
|
||||
"id": int64(1),
|
||||
"a": int64(100), // Only field 'a' is provided
|
||||
},
|
||||
{
|
||||
"id": int64(2),
|
||||
"b": int64(200), // Only field 'b' is provided
|
||||
},
|
||||
}
|
||||
|
||||
// Test with partial update = true, this should fail
|
||||
// because different rows are updating different fields
|
||||
_, err := anyToColumns(rows, nil, schema, false, true)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "has length 1, expected 2")
|
||||
})
|
||||
|
||||
t.Run("partial update with consistent missing fields should succeed", func(t *testing.T) {
|
||||
// Create a simple schema with two fields: a and b
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "test_collection",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: "id",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
IsPrimaryKey: true,
|
||||
AutoID: false,
|
||||
},
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "a",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
FieldID: 102,
|
||||
Name: "b",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
Nullable: true, // Make field 'b' nullable
|
||||
},
|
||||
},
|
||||
EnableDynamicField: false,
|
||||
}
|
||||
|
||||
// Create two rows: both rows update only field 'a', field 'b' is missing in both
|
||||
rows := []map[string]interface{}{
|
||||
{
|
||||
"id": int64(1),
|
||||
"a": int64(100), // Only field 'a' is provided
|
||||
},
|
||||
{
|
||||
"id": int64(2),
|
||||
"a": int64(200), // Only field 'a' is provided
|
||||
},
|
||||
}
|
||||
|
||||
// Test with partial update = true, this should succeed
|
||||
// because the same fields are being updated in all rows
|
||||
fieldsData, err := anyToColumns(rows, nil, schema, false, true)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, fieldsData)
|
||||
|
||||
// Should have id and a fields, but not b (since it's not provided and nullable)
|
||||
fieldNames := make(map[string]bool)
|
||||
for _, fd := range fieldsData {
|
||||
fieldNames[fd.FieldName] = true
|
||||
}
|
||||
assert.True(t, fieldNames["id"])
|
||||
assert.True(t, fieldNames["a"])
|
||||
// Field 'b' should not be present since it wasn't provided in any row
|
||||
assert.False(t, fieldNames["b"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestCheckAndSetData(t *testing.T) {
|
||||
@ -735,7 +836,7 @@ func TestCheckAndSetData(t *testing.T) {
|
||||
body := []byte("{\"data\": {\"id\": 0,\"$meta\": 2,\"book_id\": 1, \"book_intro\": [0.1, 0.2], \"word_count\": 2, \"classified\": false, \"databaseID\": null}}")
|
||||
coll := generateCollectionSchema(schemapb.DataType_Int64, false, true)
|
||||
var err error
|
||||
err, _, _ = checkAndSetData(body, coll)
|
||||
err, _, _ = checkAndSetData(body, coll, false)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, true, strings.HasPrefix(err.Error(), "use the invalid field name"))
|
||||
})
|
||||
@ -759,7 +860,7 @@ func TestCheckAndSetData(t *testing.T) {
|
||||
primaryField, floatVectorField,
|
||||
},
|
||||
EnableDynamicField: true,
|
||||
})
|
||||
}, false)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, true, strings.HasPrefix(err.Error(), "missing vector field"))
|
||||
err, _, _ = checkAndSetData(body, &schemapb.CollectionSchema{
|
||||
@ -768,7 +869,7 @@ func TestCheckAndSetData(t *testing.T) {
|
||||
primaryField, binaryVectorField,
|
||||
},
|
||||
EnableDynamicField: true,
|
||||
})
|
||||
}, false)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, true, strings.HasPrefix(err.Error(), "missing vector field"))
|
||||
err, _, _ = checkAndSetData(body, &schemapb.CollectionSchema{
|
||||
@ -777,7 +878,7 @@ func TestCheckAndSetData(t *testing.T) {
|
||||
primaryField, float16VectorField,
|
||||
},
|
||||
EnableDynamicField: true,
|
||||
})
|
||||
}, false)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, true, strings.HasPrefix(err.Error(), "missing vector field"))
|
||||
err, _, _ = checkAndSetData(body, &schemapb.CollectionSchema{
|
||||
@ -786,7 +887,7 @@ func TestCheckAndSetData(t *testing.T) {
|
||||
primaryField, bfloat16VectorField,
|
||||
},
|
||||
EnableDynamicField: true,
|
||||
})
|
||||
}, false)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, true, strings.HasPrefix(err.Error(), "missing vector field"))
|
||||
err, _, _ = checkAndSetData(body, &schemapb.CollectionSchema{
|
||||
@ -795,7 +896,7 @@ func TestCheckAndSetData(t *testing.T) {
|
||||
primaryField, int8VectorField,
|
||||
},
|
||||
EnableDynamicField: true,
|
||||
})
|
||||
}, false)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, true, strings.HasPrefix(err.Error(), "missing vector field"))
|
||||
})
|
||||
@ -809,7 +910,7 @@ func TestCheckAndSetData(t *testing.T) {
|
||||
DataType: schemapb.DataType_Array,
|
||||
ElementType: schemapb.DataType_Int64,
|
||||
})
|
||||
err, data, validData := checkAndSetData(body, coll)
|
||||
err, data, validData := checkAndSetData(body, coll, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, 1, len(data))
|
||||
assert.Equal(t, 0, len(validData))
|
||||
@ -824,7 +925,7 @@ func TestCheckAndSetData(t *testing.T) {
|
||||
DataType: schemapb.DataType_Array,
|
||||
ElementType: schemapb.DataType_Int64,
|
||||
})
|
||||
err, data, validData := checkAndSetData(body, coll)
|
||||
err, data, validData := checkAndSetData(body, coll, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, 1, len(data))
|
||||
assert.Equal(t, 0, len(validData))
|
||||
@ -839,7 +940,7 @@ func TestCheckAndSetData(t *testing.T) {
|
||||
DataType: schemapb.DataType_Array,
|
||||
ElementType: schemapb.DataType_Int64,
|
||||
})
|
||||
err, data, validData := checkAndSetData(body, coll)
|
||||
err, data, validData := checkAndSetData(body, coll, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, 1, len(data))
|
||||
assert.Equal(t, 0, len(validData))
|
||||
@ -855,7 +956,7 @@ func TestInsertWithInt64(t *testing.T) {
|
||||
DataType: schemapb.DataType_Array,
|
||||
ElementType: schemapb.DataType_Int64,
|
||||
})
|
||||
err, data, validData := checkAndSetData(body, coll)
|
||||
err, data, validData := checkAndSetData(body, coll, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, 1, len(data))
|
||||
assert.Equal(t, 0, len(validData))
|
||||
@ -878,7 +979,7 @@ func TestInsertWithNullableField(t *testing.T) {
|
||||
Nullable: true,
|
||||
})
|
||||
body := []byte("{\"data\": [{\"book_id\": 9999999999999999, \"\nullable\": null,\"book_intro\": [0.1, 0.2], \"word_count\": 2, \"" + arrayFieldName + "\": [9999999999999999]},{\"book_id\": 1, \"nullable\": 1,\"book_intro\": [0.3, 0.4], \"word_count\": 2, \"" + arrayFieldName + "\": [9999999999999999]}]")
|
||||
err, data, validData := checkAndSetData(body, coll)
|
||||
err, data, validData := checkAndSetData(body, coll, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, 2, len(data))
|
||||
assert.Equal(t, 1, len(validData))
|
||||
@ -891,7 +992,7 @@ func TestInsertWithNullableField(t *testing.T) {
|
||||
assert.Equal(t, 4, len(data[0]))
|
||||
assert.Equal(t, 5, len(data[1]))
|
||||
|
||||
fieldData, err := anyToColumns(data, validData, coll, true)
|
||||
fieldData, err := anyToColumns(data, validData, coll, true, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, len(coll.Fields), len(fieldData))
|
||||
}
|
||||
@ -914,7 +1015,7 @@ func TestInsertWithDefaultValueField(t *testing.T) {
|
||||
},
|
||||
})
|
||||
body := []byte("{\"data\": [{\"book_id\": 9999999999999999, \"\fid\": null,\"book_intro\": [0.1, 0.2], \"word_count\": 2, \"" + arrayFieldName + "\": [9999999999999999]},{\"book_id\": 1, \"fid\": 1,\"book_intro\": [0.3, 0.4], \"word_count\": 2, \"" + arrayFieldName + "\": [9999999999999999]}]")
|
||||
err, data, validData := checkAndSetData(body, coll)
|
||||
err, data, validData := checkAndSetData(body, coll, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, 2, len(data))
|
||||
assert.Equal(t, 1, len(validData))
|
||||
@ -927,7 +1028,7 @@ func TestInsertWithDefaultValueField(t *testing.T) {
|
||||
assert.Equal(t, 4, len(data[0]))
|
||||
assert.Equal(t, 5, len(data[1]))
|
||||
|
||||
fieldData, err := anyToColumns(data, validData, coll, true)
|
||||
fieldData, err := anyToColumns(data, validData, coll, true, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, len(coll.Fields), len(fieldData))
|
||||
}
|
||||
@ -2078,21 +2179,21 @@ func newRowsWithArray(results []map[string]interface{}) []map[string]interface{}
|
||||
func TestArray(t *testing.T) {
|
||||
body, _ := generateRequestBody(schemapb.DataType_Int64)
|
||||
collectionSchema := generateCollectionSchema(schemapb.DataType_Int64, false, true)
|
||||
err, rows, validRows := checkAndSetData(body, collectionSchema)
|
||||
err, rows, validRows := checkAndSetData(body, collectionSchema, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, 0, len(validRows))
|
||||
assert.Equal(t, true, compareRows(rows, generateRawRows(schemapb.DataType_Int64), compareRow))
|
||||
data, err := anyToColumns(rows, validRows, collectionSchema, true)
|
||||
data, err := anyToColumns(rows, validRows, collectionSchema, true, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, len(collectionSchema.Fields), len(data))
|
||||
|
||||
body, _ = generateRequestBodyWithArray(schemapb.DataType_Int64)
|
||||
collectionSchema = newCollectionSchemaWithArray(generateCollectionSchema(schemapb.DataType_Int64, false, true))
|
||||
err, rows, validRows = checkAndSetData(body, collectionSchema)
|
||||
err, rows, validRows = checkAndSetData(body, collectionSchema, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, 0, len(validRows))
|
||||
assert.Equal(t, true, compareRows(rows, newRowsWithArray(generateRawRows(schemapb.DataType_Int64)), compareRow))
|
||||
data, err = anyToColumns(rows, validRows, collectionSchema, true)
|
||||
data, err = anyToColumns(rows, validRows, collectionSchema, true, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, len(collectionSchema.Fields), len(data))
|
||||
}
|
||||
@ -2175,7 +2276,7 @@ func TestVector(t *testing.T) {
|
||||
},
|
||||
EnableDynamicField: true,
|
||||
}
|
||||
err, rows, validRows := checkAndSetData(body, collectionSchema)
|
||||
err, rows, validRows := checkAndSetData(body, collectionSchema, false)
|
||||
assert.Equal(t, nil, err)
|
||||
for i, row := range rows {
|
||||
assert.Equal(t, 2, len(row[floatVector].([]float32)))
|
||||
@ -2200,7 +2301,7 @@ func TestVector(t *testing.T) {
|
||||
assert.Equal(t, 16, len(row[sparseFloatVector].([]byte)))
|
||||
}
|
||||
assert.Equal(t, 0, len(validRows))
|
||||
data, err := anyToColumns(rows, validRows, collectionSchema, true)
|
||||
data, err := anyToColumns(rows, validRows, collectionSchema, true, false)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, len(collectionSchema.Fields)+1, len(data))
|
||||
|
||||
@ -2211,7 +2312,7 @@ func TestVector(t *testing.T) {
|
||||
}
|
||||
row[field] = value
|
||||
body, _ = wrapRequestBody([]map[string]interface{}{row})
|
||||
err, _, _ = checkAndSetData(body, collectionSchema)
|
||||
err, _, _ = checkAndSetData(body, collectionSchema, false)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
|
||||
@ -7650,7 +7650,8 @@ func (_c *MockProxy_Upsert_Call) RunAndReturn(run func(context.Context, *milvusp
|
||||
func NewMockProxy(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *MockProxy {
|
||||
},
|
||||
) *MockProxy {
|
||||
mock := &MockProxy{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
|
||||
@ -2845,10 +2845,11 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest)
|
||||
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.String("role", typeutil.ProxyRole),
|
||||
zap.String("db", request.DbName),
|
||||
zap.String("collection", request.CollectionName),
|
||||
zap.String("partition", request.PartitionName),
|
||||
zap.Uint32("NumRows", request.NumRows),
|
||||
zap.String("db", request.GetDbName()),
|
||||
zap.String("collection", request.GetCollectionName()),
|
||||
zap.String("partition", request.GetPartitionName()),
|
||||
zap.Uint32("NumRows", request.GetNumRows()),
|
||||
zap.Bool("partialUpdate", request.GetPartialUpdate()),
|
||||
)
|
||||
log.Debug("Start processing upsert request in Proxy")
|
||||
|
||||
@ -2890,6 +2891,7 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest)
|
||||
idAllocator: node.rowIDAllocator,
|
||||
chMgr: node.chMgr,
|
||||
schemaTimestamp: request.SchemaTimestamp,
|
||||
node: node,
|
||||
}
|
||||
|
||||
log.Debug("Enqueue upsert request in Proxy",
|
||||
|
||||
@ -17,6 +17,7 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
@ -28,11 +29,14 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/function"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
@ -66,6 +70,12 @@ type upsertTask struct {
|
||||
// delete task need use the oldIDs
|
||||
oldIDs *schemapb.IDs
|
||||
schemaTimestamp uint64
|
||||
|
||||
// write after read, generate write part by queryPreExecute
|
||||
node types.ProxyComponent
|
||||
|
||||
deletePKs *schemapb.IDs
|
||||
insertFieldData []*schemapb.FieldData
|
||||
}
|
||||
|
||||
// TraceCtx returns upsertTask context
|
||||
@ -145,6 +155,361 @@ func (it *upsertTask) OnEnqueue() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func retrieveByPKs(ctx context.Context, t *upsertTask, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) {
|
||||
log := log.Ctx(ctx).With(zap.String("collectionName", t.req.GetCollectionName()))
|
||||
var err error
|
||||
queryReq := &milvuspb.QueryRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Retrieve,
|
||||
Timestamp: t.BeginTs(),
|
||||
},
|
||||
DbName: t.req.GetDbName(),
|
||||
CollectionName: t.req.GetCollectionName(),
|
||||
ConsistencyLevel: commonpb.ConsistencyLevel_Strong,
|
||||
NotReturnAllMeta: false,
|
||||
OutputFields: []string{"*"},
|
||||
UseDefaultConsistency: false,
|
||||
GuaranteeTimestamp: t.BeginTs(),
|
||||
}
|
||||
pkField, err := typeutil.GetPrimaryFieldSchema(t.schema.CollectionSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var partitionIDs []int64
|
||||
if t.partitionKeyMode {
|
||||
// multi entities with same pk and diff partition keys may be hashed to multi physical partitions
|
||||
// if deleteMsg.partitionID = common.InvalidPartition,
|
||||
// all segments with this pk under the collection will have the delete record
|
||||
partitionIDs = []int64{common.AllPartitionsID}
|
||||
queryReq.PartitionNames = []string{}
|
||||
} else {
|
||||
// partition name could be defaultPartitionName or name specified by sdk
|
||||
partName := t.upsertMsg.DeleteMsg.PartitionName
|
||||
if err := validatePartitionTag(partName, true); err != nil {
|
||||
log.Warn("Invalid partition name", zap.String("partitionName", partName), zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
partID, err := globalMetaCache.GetPartitionID(ctx, t.req.GetDbName(), t.req.GetCollectionName(), partName)
|
||||
if err != nil {
|
||||
log.Warn("Failed to get partition id", zap.String("partitionName", partName), zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
partitionIDs = []int64{partID}
|
||||
queryReq.PartitionNames = []string{partName}
|
||||
}
|
||||
|
||||
plan := planparserv2.CreateRequeryPlan(pkField, ids)
|
||||
qt := &queryTask{
|
||||
ctx: t.ctx,
|
||||
Condition: NewTaskCondition(t.ctx),
|
||||
RetrieveRequest: &internalpb.RetrieveRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
|
||||
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||
),
|
||||
ReqID: paramtable.GetNodeID(),
|
||||
PartitionIDs: partitionIDs,
|
||||
ConsistencyLevel: commonpb.ConsistencyLevel_Strong,
|
||||
},
|
||||
request: queryReq,
|
||||
plan: plan,
|
||||
mixCoord: t.node.(*Proxy).mixCoord,
|
||||
lb: t.node.(*Proxy).lbPolicy,
|
||||
}
|
||||
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-retrieveByPKs")
|
||||
defer func() {
|
||||
sp.End()
|
||||
}()
|
||||
queryResult, err := t.node.(*Proxy).query(ctx, qt, sp)
|
||||
if err := merr.CheckRPCCall(queryResult.GetStatus(), err); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return queryResult, err
|
||||
}
|
||||
|
||||
func (it *upsertTask) queryPreExecute(ctx context.Context) error {
|
||||
log := log.Ctx(ctx).With(zap.String("collectionName", it.req.CollectionName))
|
||||
|
||||
primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(it.schema.CollectionSchema)
|
||||
if err != nil {
|
||||
log.Warn("get primary field schema failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
primaryFieldData, err := typeutil.GetPrimaryFieldData(it.req.GetFieldsData(), primaryFieldSchema)
|
||||
if err != nil {
|
||||
log.Error("get primary field data failed", zap.Error(err))
|
||||
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("must assign pk when upsert, primary field: %v", primaryFieldSchema.Name))
|
||||
}
|
||||
|
||||
oldIDs, err := parsePrimaryFieldData2IDs(primaryFieldData)
|
||||
if err != nil {
|
||||
log.Warn("parse primary field data to IDs failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
oldIDSize := typeutil.GetSizeOfIDs(oldIDs)
|
||||
if oldIDSize == 0 {
|
||||
it.deletePKs = &schemapb.IDs{}
|
||||
it.insertFieldData = it.req.GetFieldsData()
|
||||
log.Info("old records not found, just do insert")
|
||||
return nil
|
||||
}
|
||||
|
||||
tr := timerecord.NewTimeRecorder("Proxy-Upsert-retrieveByPKs")
|
||||
// retrieve by primary key to get original field data
|
||||
resp, err := retrieveByPKs(ctx, it, oldIDs, []string{"*"})
|
||||
if err != nil {
|
||||
log.Info("retrieve by primary key failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
if len(resp.GetFieldsData()) == 0 {
|
||||
return merr.WrapErrParameterInvalidMsg("retrieve by primary key failed, no data found")
|
||||
}
|
||||
|
||||
existFieldData := resp.GetFieldsData()
|
||||
pkFieldData, err := typeutil.GetPrimaryFieldData(existFieldData, primaryFieldSchema)
|
||||
if err != nil {
|
||||
log.Error("get primary field data failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
existIDs, err := parsePrimaryFieldData2IDs(pkFieldData)
|
||||
if err != nil {
|
||||
log.Info("parse primary field data to ids failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
log.Info("retrieveByPKs cost",
|
||||
zap.Int("resultNum", typeutil.GetSizeOfIDs(existIDs)),
|
||||
zap.Int64("latency", tr.ElapseSpan().Milliseconds()))
|
||||
|
||||
// check whether the primary key is exist in query result
|
||||
idsChecker, err := typeutil.NewIDsChecker(existIDs)
|
||||
if err != nil {
|
||||
log.Info("create primary key checker failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// Build mapping from existing primary keys to their positions in query result
|
||||
// This ensures we can correctly locate data even if query results are not in the same order as request
|
||||
existIDsLen := typeutil.GetSizeOfIDs(existIDs)
|
||||
existPKToIndex := make(map[interface{}]int, existIDsLen)
|
||||
for j := 0; j < existIDsLen; j++ {
|
||||
pk := typeutil.GetPK(existIDs, int64(j))
|
||||
existPKToIndex[pk] = j
|
||||
}
|
||||
|
||||
// set field id for user passed field data
|
||||
upsertFieldData := it.upsertMsg.InsertMsg.GetFieldsData()
|
||||
if len(upsertFieldData) == 0 {
|
||||
return merr.WrapErrParameterInvalidMsg("upsert field data is empty")
|
||||
}
|
||||
for _, fieldData := range upsertFieldData {
|
||||
fieldName := fieldData.GetFieldName()
|
||||
if fieldData.GetIsDynamic() {
|
||||
fieldName = "$meta"
|
||||
}
|
||||
fieldID, ok := it.schema.MapFieldID(fieldName)
|
||||
if !ok {
|
||||
log.Info("field not found in schema", zap.Any("field", fieldData))
|
||||
return merr.WrapErrParameterInvalidMsg("field not found in schema")
|
||||
}
|
||||
fieldData.FieldId = fieldID
|
||||
fieldData.FieldName = fieldName
|
||||
}
|
||||
|
||||
lackOfFieldErr := LackOfFieldsDataBySchema(it.schema.CollectionSchema, it.upsertMsg.InsertMsg.GetFieldsData(), false, true)
|
||||
it.deletePKs = &schemapb.IDs{}
|
||||
it.insertFieldData = make([]*schemapb.FieldData, len(existFieldData))
|
||||
for i := 0; i < oldIDSize; i++ {
|
||||
exist, err := idsChecker.Contains(oldIDs, i)
|
||||
if err != nil {
|
||||
log.Info("check primary key exist in query result failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
if exist {
|
||||
// treat upsert as update
|
||||
// 1. if pk exist in query result, add it to deletePKs
|
||||
typeutil.AppendIDs(it.deletePKs, oldIDs, i)
|
||||
// 2. construct the field data for update using correct index mapping
|
||||
oldPK := typeutil.GetPK(oldIDs, int64(i))
|
||||
existIndex, ok := existPKToIndex[oldPK]
|
||||
if !ok {
|
||||
return merr.WrapErrParameterInvalidMsg("primary key not found in exist data mapping")
|
||||
}
|
||||
typeutil.AppendFieldData(it.insertFieldData, existFieldData, int64(existIndex))
|
||||
err := typeutil.UpdateFieldData(it.insertFieldData, upsertFieldData, int64(i))
|
||||
if err != nil {
|
||||
log.Info("update field data failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// treat upsert as insert
|
||||
if lackOfFieldErr != nil {
|
||||
log.Info("check fields data by schema failed", zap.Error(lackOfFieldErr))
|
||||
return lackOfFieldErr
|
||||
}
|
||||
// use field data from upsert request
|
||||
typeutil.AppendFieldData(it.insertFieldData, upsertFieldData, int64(i))
|
||||
}
|
||||
}
|
||||
|
||||
for _, fieldData := range it.insertFieldData {
|
||||
if fieldData.GetIsDynamic() {
|
||||
continue
|
||||
}
|
||||
fieldSchema, err := it.schema.schemaHelper.GetFieldFromName(fieldData.GetFieldName())
|
||||
if err != nil {
|
||||
log.Info("get field schema failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// Note: Since protobuf cannot correctly identify null values, zero values + valid data are used to identify null values,
|
||||
// therefore for field data obtained from query results, if the field is nullable, it needs to be set to empty values
|
||||
if fieldSchema.GetNullable() {
|
||||
if getValidNumber(fieldData.GetValidData()) != len(fieldData.GetValidData()) {
|
||||
err := ResetNullFieldData(fieldData, fieldSchema)
|
||||
if err != nil {
|
||||
log.Info("reset null field data failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Note: For fields containing default values, default values need to be set according to valid data during insertion,
|
||||
// but query results fields do not set valid data when returning default value fields,
|
||||
// therefore valid data needs to be manually set to true
|
||||
if fieldSchema.GetDefaultValue() != nil {
|
||||
fieldData.ValidData = make([]bool, oldIDSize)
|
||||
for i := range fieldData.ValidData {
|
||||
fieldData.ValidData[i] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ResetNullFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
|
||||
if !fieldSchema.GetNullable() {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch field.Field.(type) {
|
||||
case *schemapb.FieldData_Scalars:
|
||||
switch sd := field.GetScalars().GetData().(type) {
|
||||
case *schemapb.ScalarField_BoolData:
|
||||
validRowNum := getValidNumber(field.GetValidData())
|
||||
if validRowNum == 0 {
|
||||
sd.BoolData.Data = make([]bool, 0)
|
||||
} else {
|
||||
ret := make([]bool, validRowNum)
|
||||
for i, valid := range field.GetValidData() {
|
||||
if valid {
|
||||
ret = append(ret, sd.BoolData.Data[i])
|
||||
}
|
||||
}
|
||||
sd.BoolData.Data = ret
|
||||
}
|
||||
|
||||
case *schemapb.ScalarField_IntData:
|
||||
validRowNum := getValidNumber(field.GetValidData())
|
||||
if validRowNum == 0 {
|
||||
sd.IntData.Data = make([]int32, 0)
|
||||
} else {
|
||||
ret := make([]int32, validRowNum)
|
||||
for i, valid := range field.GetValidData() {
|
||||
if valid {
|
||||
ret = append(ret, sd.IntData.Data[i])
|
||||
}
|
||||
}
|
||||
sd.IntData.Data = ret
|
||||
}
|
||||
|
||||
case *schemapb.ScalarField_LongData:
|
||||
validRowNum := getValidNumber(field.GetValidData())
|
||||
if validRowNum == 0 {
|
||||
sd.LongData.Data = make([]int64, 0)
|
||||
} else {
|
||||
ret := make([]int64, validRowNum)
|
||||
for i, valid := range field.GetValidData() {
|
||||
if valid {
|
||||
ret = append(ret, sd.LongData.Data[i])
|
||||
}
|
||||
}
|
||||
sd.LongData.Data = ret
|
||||
}
|
||||
|
||||
case *schemapb.ScalarField_FloatData:
|
||||
validRowNum := getValidNumber(field.GetValidData())
|
||||
if validRowNum == 0 {
|
||||
sd.FloatData.Data = make([]float32, 0)
|
||||
} else {
|
||||
ret := make([]float32, validRowNum)
|
||||
for i, valid := range field.GetValidData() {
|
||||
if valid {
|
||||
ret = append(ret, sd.FloatData.Data[i])
|
||||
}
|
||||
}
|
||||
sd.FloatData.Data = ret
|
||||
}
|
||||
|
||||
case *schemapb.ScalarField_DoubleData:
|
||||
validRowNum := getValidNumber(field.GetValidData())
|
||||
if validRowNum == 0 {
|
||||
sd.DoubleData.Data = make([]float64, 0)
|
||||
} else {
|
||||
ret := make([]float64, validRowNum)
|
||||
for i, valid := range field.GetValidData() {
|
||||
if valid {
|
||||
ret = append(ret, sd.DoubleData.Data[i])
|
||||
}
|
||||
}
|
||||
sd.DoubleData.Data = ret
|
||||
}
|
||||
|
||||
case *schemapb.ScalarField_StringData:
|
||||
validRowNum := getValidNumber(field.GetValidData())
|
||||
if validRowNum == 0 {
|
||||
sd.StringData.Data = make([]string, 0)
|
||||
} else {
|
||||
ret := make([]string, validRowNum)
|
||||
for i, valid := range field.GetValidData() {
|
||||
if valid {
|
||||
ret = append(ret, sd.StringData.Data[i])
|
||||
}
|
||||
}
|
||||
sd.StringData.Data = ret
|
||||
}
|
||||
|
||||
case *schemapb.ScalarField_JsonData:
|
||||
validRowNum := getValidNumber(field.GetValidData())
|
||||
if validRowNum == 0 {
|
||||
sd.JsonData.Data = make([][]byte, 0)
|
||||
} else {
|
||||
ret := make([][]byte, validRowNum)
|
||||
for i, valid := range field.GetValidData() {
|
||||
if valid {
|
||||
ret = append(ret, sd.JsonData.Data[i])
|
||||
}
|
||||
}
|
||||
sd.JsonData.Data = ret
|
||||
}
|
||||
|
||||
default:
|
||||
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("undefined data type:%s", field.Type.String()))
|
||||
}
|
||||
|
||||
default:
|
||||
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("undefined data type:%s", field.Type.String()))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *upsertTask) insertPreExecute(ctx context.Context) error {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-insertPreExecute")
|
||||
defer sp.End()
|
||||
@ -154,8 +519,20 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
bm25Fields := typeutil.NewSet[string](GetFunctionOutputFields(it.schema.CollectionSchema)...)
|
||||
// Calculate embedding fields
|
||||
if function.HasNonBM25Functions(it.schema.CollectionSchema.Functions, []int64{}) {
|
||||
if it.req.PartialUpdate {
|
||||
// remove the old bm25 fields
|
||||
ret := make([]*schemapb.FieldData, 0)
|
||||
for _, fieldData := range it.upsertMsg.InsertMsg.GetFieldsData() {
|
||||
if bm25Fields.Contain(fieldData.GetFieldName()) {
|
||||
continue
|
||||
}
|
||||
ret = append(ret, fieldData)
|
||||
}
|
||||
it.upsertMsg.InsertMsg.FieldsData = ret
|
||||
}
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Proxy-Upsert-insertPreExecute-call-function-udf")
|
||||
defer sp.End()
|
||||
exec, err := function.NewFunctionExecutor(it.schema.CollectionSchema)
|
||||
@ -271,17 +648,19 @@ func (it *upsertTask) deletePreExecute(ctx context.Context) error {
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.String("collectionName", collName))
|
||||
|
||||
if it.upsertMsg.DeleteMsg.PrimaryKeys == nil {
|
||||
// if primary keys are not set by queryPreExecute, use oldIDs to delete all given records
|
||||
it.upsertMsg.DeleteMsg.PrimaryKeys = it.oldIDs
|
||||
}
|
||||
if typeutil.GetSizeOfIDs(it.upsertMsg.DeleteMsg.PrimaryKeys) == 0 {
|
||||
log.Info("deletePKs is empty, skip deleteExecute")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := validateCollectionName(collName); err != nil {
|
||||
log.Info("Invalid collectionName", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
collID, err := globalMetaCache.GetCollectionID(ctx, it.req.GetDbName(), collName)
|
||||
if err != nil {
|
||||
log.Info("Failed to get collection id", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
it.upsertMsg.DeleteMsg.CollectionID = collID
|
||||
it.collectionID = collID
|
||||
|
||||
if it.partitionKeyMode {
|
||||
// multi entities with same pk and diff partition keys may be hashed to multi physical partitions
|
||||
@ -335,11 +714,14 @@ func (it *upsertTask) PreExecute(ctx context.Context) error {
|
||||
return merr.WrapErrCollectionReplicateMode("upsert")
|
||||
}
|
||||
|
||||
// check collection exists
|
||||
collID, err := globalMetaCache.GetCollectionID(context.Background(), it.req.GetDbName(), collectionName)
|
||||
if err != nil {
|
||||
log.Warn("fail to get collection id", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
it.collectionID = collID
|
||||
|
||||
colInfo, err := globalMetaCache.GetCollectionInfo(ctx, it.req.GetDbName(), collectionName, collID)
|
||||
if err != nil {
|
||||
log.Warn("fail to get collection info", zap.Error(err))
|
||||
@ -398,6 +780,7 @@ func (it *upsertTask) PreExecute(ctx context.Context) error {
|
||||
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||
),
|
||||
CollectionName: it.req.CollectionName,
|
||||
CollectionID: it.collectionID,
|
||||
PartitionName: it.req.PartitionName,
|
||||
FieldsData: it.req.FieldsData,
|
||||
NumRows: uint64(it.req.NumRows),
|
||||
@ -413,12 +796,30 @@ func (it *upsertTask) PreExecute(ctx context.Context) error {
|
||||
),
|
||||
DbName: it.req.DbName,
|
||||
CollectionName: it.req.CollectionName,
|
||||
CollectionID: it.collectionID,
|
||||
NumRows: int64(it.req.NumRows),
|
||||
PartitionName: it.req.PartitionName,
|
||||
CollectionID: it.collectionID,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// check if num_rows is valid
|
||||
if it.req.NumRows <= 0 {
|
||||
return merr.WrapErrParameterInvalid("invalid num_rows", fmt.Sprint(it.req.NumRows), "num_rows should be greater than 0")
|
||||
}
|
||||
|
||||
if it.req.GetPartialUpdate() {
|
||||
err = it.queryPreExecute(ctx)
|
||||
if err != nil {
|
||||
log.Warn("Fail to queryPreExecute", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
// reconstruct upsert msg after queryPreExecute
|
||||
it.upsertMsg.InsertMsg.FieldsData = it.insertFieldData
|
||||
it.upsertMsg.DeleteMsg.PrimaryKeys = it.deletePKs
|
||||
it.upsertMsg.DeleteMsg.NumRows = int64(typeutil.GetSizeOfIDs(it.deletePKs))
|
||||
}
|
||||
|
||||
err = it.insertPreExecute(ctx)
|
||||
if err != nil {
|
||||
log.Warn("Fail to insertPreExecute", zap.Error(err))
|
||||
|
||||
@ -19,6 +19,7 @@ import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/bytedance/mockey"
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
@ -28,14 +29,18 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
grpcmixcoordclient "github.com/milvus-io/milvus/internal/distributed/mixcoord/client"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
||||
"github.com/milvus-io/milvus/internal/util/function"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/testutils"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
func TestUpsertTask_CheckAligned(t *testing.T) {
|
||||
@ -156,7 +161,7 @@ func TestUpsertTask_CheckAligned(t *testing.T) {
|
||||
case2.req.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows)
|
||||
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
||||
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, nil, err)
|
||||
|
||||
// less int8 data
|
||||
case2.req.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows/2)
|
||||
@ -520,6 +525,7 @@ func TestUpsertTaskForSchemaMismatch(t *testing.T) {
|
||||
ctx: ctx,
|
||||
req: &milvuspb.UpsertRequest{
|
||||
CollectionName: "col-0",
|
||||
NumRows: 10,
|
||||
},
|
||||
schemaTimestamp: 99,
|
||||
}
|
||||
@ -533,3 +539,546 @@ func TestUpsertTaskForSchemaMismatch(t *testing.T) {
|
||||
assert.ErrorIs(t, err, merr.ErrCollectionSchemaMismatch)
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to create test updateTask
|
||||
func createTestUpdateTask() *upsertTask {
|
||||
mcClient := &grpcmixcoordclient.Client{}
|
||||
|
||||
upsertTask := &upsertTask{
|
||||
baseTask: baseTask{},
|
||||
Condition: NewTaskCondition(context.Background()),
|
||||
req: &milvuspb.UpsertRequest{
|
||||
DbName: "test_db",
|
||||
CollectionName: "test_collection",
|
||||
PartitionName: "_default",
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
{
|
||||
FieldName: "id",
|
||||
FieldId: 100,
|
||||
Type: schemapb.DataType_Int64,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldName: "name",
|
||||
FieldId: 102,
|
||||
Type: schemapb.DataType_VarChar,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{Data: []string{"test1", "test2", "test3"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldName: "vector",
|
||||
FieldId: 101,
|
||||
Type: schemapb.DataType_FloatVector,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: 128,
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{Data: make([]float32, 384)}, // 3 * 128
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
NumRows: 3,
|
||||
},
|
||||
ctx: context.Background(),
|
||||
schema: createTestSchema(),
|
||||
collectionID: 1001,
|
||||
node: &Proxy{
|
||||
mixCoord: mcClient,
|
||||
lbPolicy: NewLBPolicyImpl(nil),
|
||||
},
|
||||
}
|
||||
|
||||
return upsertTask
|
||||
}
|
||||
|
||||
// Helper function to create test schema
|
||||
func createTestSchema() *schemaInfo {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "test_collection",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: "id",
|
||||
IsPrimaryKey: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "vector",
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "128"},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 102,
|
||||
Name: "name",
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
},
|
||||
},
|
||||
}
|
||||
return newSchemaInfo(schema)
|
||||
}
|
||||
|
||||
func TestRetrieveByPKs_Success(t *testing.T) {
|
||||
mockey.PatchConvey("TestRetrieveByPKs_Success", t, func() {
|
||||
// Setup mocks
|
||||
mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(&schemapb.FieldSchema{
|
||||
FieldID: 100,
|
||||
Name: "id",
|
||||
IsPrimaryKey: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
}, nil).Build()
|
||||
|
||||
mockey.Mock(validatePartitionTag).Return(nil).Build()
|
||||
|
||||
mockey.Mock((*MetaCache).GetPartitionID).Return(int64(1002), nil).Build()
|
||||
|
||||
mockey.Mock(planparserv2.CreateRequeryPlan).Return(&planpb.PlanNode{}).Build()
|
||||
|
||||
mockey.Mock((*Proxy).query).Return(&milvuspb.QueryResults{
|
||||
Status: merr.Success(),
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
{
|
||||
FieldName: "id",
|
||||
FieldId: 100,
|
||||
Type: schemapb.DataType_Int64,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{Data: []int64{1, 2}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil).Build()
|
||||
|
||||
globalMetaCache = &MetaCache{}
|
||||
mockey.Mock(globalMetaCache.GetPartitionID).Return(int64(1002), nil).Build()
|
||||
|
||||
// Execute test
|
||||
task := createTestUpdateTask()
|
||||
task.partitionKeyMode = false
|
||||
task.upsertMsg = &msgstream.UpsertMsg{
|
||||
InsertMsg: &msgstream.InsertMsg{
|
||||
InsertRequest: &msgpb.InsertRequest{
|
||||
PartitionName: "_default",
|
||||
},
|
||||
},
|
||||
DeleteMsg: &msgstream.DeleteMsg{
|
||||
DeleteRequest: &msgpb.DeleteRequest{
|
||||
PartitionName: "_default",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ids := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{Data: []int64{1, 2}},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := retrieveByPKs(context.Background(), task, ids, []string{"*"})
|
||||
|
||||
// Verify results
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, result.Status.ErrorCode)
|
||||
assert.Len(t, result.FieldsData, 1)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRetrieveByPKs_GetPrimaryFieldSchemaError(t *testing.T) {
|
||||
mockey.PatchConvey("TestRetrieveByPKs_GetPrimaryFieldSchemaError", t, func() {
|
||||
expectedErr := merr.WrapErrParameterInvalidMsg("primary field not found")
|
||||
mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(nil, expectedErr).Build()
|
||||
|
||||
task := createTestUpdateTask()
|
||||
ids := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{Data: []int64{1, 2}},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := retrieveByPKs(context.Background(), task, ids, []string{"*"})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "primary field not found")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRetrieveByPKs_PartitionKeyMode(t *testing.T) {
|
||||
mockey.PatchConvey("TestRetrieveByPKs_PartitionKeyMode", t, func() {
|
||||
mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(&schemapb.FieldSchema{
|
||||
FieldID: 100,
|
||||
Name: "id",
|
||||
IsPrimaryKey: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
}, nil).Build()
|
||||
|
||||
mockey.Mock(planparserv2.CreateRequeryPlan).Return(&planpb.PlanNode{}).Build()
|
||||
|
||||
mockey.Mock((*Proxy).query).Return(&milvuspb.QueryResults{
|
||||
Status: merr.Success(),
|
||||
FieldsData: []*schemapb.FieldData{},
|
||||
}, nil).Build()
|
||||
|
||||
task := createTestUpdateTask()
|
||||
task.partitionKeyMode = true
|
||||
|
||||
ids := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{Data: []int64{1, 2}},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := retrieveByPKs(context.Background(), task, ids, []string{"*"})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateTask_queryPreExecute_Success(t *testing.T) {
|
||||
mockey.PatchConvey("TestUpdateTask_queryPreExecute_Success", t, func() {
|
||||
// Setup mocks
|
||||
mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(&schemapb.FieldSchema{
|
||||
FieldID: 100,
|
||||
Name: "id",
|
||||
IsPrimaryKey: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
}, nil).Build()
|
||||
|
||||
mockey.Mock(typeutil.GetPrimaryFieldData).Return(&schemapb.FieldData{
|
||||
FieldName: "id",
|
||||
FieldId: 100,
|
||||
Type: schemapb.DataType_Int64,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil).Build()
|
||||
|
||||
mockey.Mock(parsePrimaryFieldData2IDs).Return(&schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{Data: []int64{1, 2, 3}},
|
||||
},
|
||||
}, nil).Build()
|
||||
|
||||
mockey.Mock(typeutil.GetSizeOfIDs).Return(3).Build()
|
||||
|
||||
mockey.Mock(retrieveByPKs).Return(&milvuspb.QueryResults{
|
||||
Status: merr.Success(),
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
{
|
||||
FieldName: "id",
|
||||
FieldId: 100,
|
||||
Type: schemapb.DataType_Int64,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{Data: []int64{1, 2}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldName: "name",
|
||||
FieldId: 102,
|
||||
Type: schemapb.DataType_VarChar,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{Data: []string{"old1", "old2"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldName: "vector",
|
||||
FieldId: 101,
|
||||
Type: schemapb.DataType_FloatVector,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: 128,
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{Data: make([]float32, 256)}, // 2 * 128
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil).Build()
|
||||
|
||||
mockey.Mock(typeutil.NewIDsChecker).Return(&typeutil.IDsChecker{}, nil).Build()
|
||||
|
||||
// Execute test
|
||||
task := createTestUpdateTask()
|
||||
task.schema = createTestSchema()
|
||||
task.upsertMsg = &msgstream.UpsertMsg{
|
||||
InsertMsg: &msgstream.InsertMsg{
|
||||
InsertRequest: &msgpb.InsertRequest{
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
{
|
||||
FieldName: "id",
|
||||
FieldId: 100,
|
||||
Type: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
FieldName: "name",
|
||||
FieldId: 102,
|
||||
Type: schemapb.DataType_VarChar,
|
||||
},
|
||||
{
|
||||
FieldName: "vector",
|
||||
FieldId: 101,
|
||||
Type: schemapb.DataType_FloatVector,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := task.queryPreExecute(context.Background())
|
||||
|
||||
// Verify results
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, task.deletePKs)
|
||||
assert.NotNil(t, task.insertFieldData)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateTask_queryPreExecute_GetPrimaryFieldSchemaError(t *testing.T) {
|
||||
mockey.PatchConvey("TestUpdateTask_queryPreExecute_GetPrimaryFieldSchemaError", t, func() {
|
||||
expectedErr := merr.WrapErrParameterInvalidMsg("primary field not found")
|
||||
mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(nil, expectedErr).Build()
|
||||
|
||||
task := createTestUpdateTask()
|
||||
task.schema = createTestSchema()
|
||||
|
||||
err := task.queryPreExecute(context.Background())
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "primary field not found")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateTask_queryPreExecute_GetPrimaryFieldDataError(t *testing.T) {
|
||||
mockey.PatchConvey("TestUpdateTask_queryPreExecute_GetPrimaryFieldDataError", t, func() {
|
||||
mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(&schemapb.FieldSchema{
|
||||
FieldID: 100,
|
||||
Name: "id",
|
||||
IsPrimaryKey: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
}, nil).Build()
|
||||
|
||||
expectedErr := merr.WrapErrParameterInvalidMsg("primary field data not found")
|
||||
mockey.Mock(typeutil.GetPrimaryFieldData).Return(nil, expectedErr).Build()
|
||||
|
||||
task := createTestUpdateTask()
|
||||
task.schema = createTestSchema()
|
||||
|
||||
err := task.queryPreExecute(context.Background())
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "must assign pk when upsert")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateTask_queryPreExecute_EmptyOldIDs(t *testing.T) {
|
||||
mockey.PatchConvey("TestUpdateTask_queryPreExecute_EmptyOldIDs", t, func() {
|
||||
mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(&schemapb.FieldSchema{
|
||||
FieldID: 100,
|
||||
Name: "id",
|
||||
IsPrimaryKey: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
}, nil).Build()
|
||||
|
||||
mockey.Mock(typeutil.GetPrimaryFieldData).Return(&schemapb.FieldData{
|
||||
FieldName: "id",
|
||||
FieldId: 100,
|
||||
Type: schemapb.DataType_Int64,
|
||||
}, nil).Build()
|
||||
|
||||
mockey.Mock(parsePrimaryFieldData2IDs).Return(&schemapb.IDs{}, nil).Build()
|
||||
|
||||
mockey.Mock(typeutil.GetSizeOfIDs).Return(0).Build()
|
||||
|
||||
task := createTestUpdateTask()
|
||||
task.schema = createTestSchema()
|
||||
|
||||
err := task.queryPreExecute(context.Background())
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, task.deletePKs)
|
||||
assert.Equal(t, task.req.GetFieldsData(), task.insertFieldData)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateTask_PreExecute_Success(t *testing.T) {
|
||||
mockey.PatchConvey("TestUpdateTask_PreExecute_Success", t, func() {
|
||||
// Setup mocks
|
||||
globalMetaCache = &MetaCache{}
|
||||
|
||||
mockey.Mock(GetReplicateID).Return("", nil).Build()
|
||||
|
||||
mockey.Mock((*MetaCache).GetCollectionID).Return(int64(1001), nil).Build()
|
||||
|
||||
mockey.Mock((*MetaCache).GetCollectionInfo).Return(&collectionInfo{
|
||||
updateTimestamp: 12345,
|
||||
}, nil).Build()
|
||||
|
||||
mockey.Mock((*MetaCache).GetCollectionSchema).Return(createTestSchema(), nil).Build()
|
||||
|
||||
mockey.Mock(isPartitionKeyMode).Return(false, nil).Build()
|
||||
|
||||
mockey.Mock((*MetaCache).GetPartitionInfo).Return(&partitionInfo{
|
||||
name: "_default",
|
||||
}, nil).Build()
|
||||
|
||||
mockey.Mock((*upsertTask).queryPreExecute).Return(nil).Build()
|
||||
|
||||
mockey.Mock((*upsertTask).insertPreExecute).Return(nil).Build()
|
||||
|
||||
mockey.Mock((*upsertTask).deletePreExecute).Return(nil).Build()
|
||||
|
||||
// Execute test
|
||||
task := createTestUpdateTask()
|
||||
task.req.PartialUpdate = true
|
||||
|
||||
err := task.PreExecute(context.Background())
|
||||
|
||||
// Verify results
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, task.result)
|
||||
assert.Equal(t, int64(1001), task.collectionID)
|
||||
assert.NotNil(t, task.schema)
|
||||
assert.NotNil(t, task.upsertMsg)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateTask_PreExecute_ReplicateIDError(t *testing.T) {
|
||||
mockey.PatchConvey("TestUpdateTask_PreExecute_ReplicateIDError", t, func() {
|
||||
globalMetaCache = &MetaCache{}
|
||||
|
||||
mockey.Mock(GetReplicateID).Return("replica1", nil).Build()
|
||||
|
||||
task := createTestUpdateTask()
|
||||
|
||||
err := task.PreExecute(context.Background())
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "can't operate on the collection under standby mode")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateTask_PreExecute_GetCollectionIDError(t *testing.T) {
|
||||
mockey.PatchConvey("TestUpdateTask_PreExecute_GetCollectionIDError", t, func() {
|
||||
globalMetaCache = &MetaCache{}
|
||||
|
||||
mockey.Mock(GetReplicateID).Return("", nil).Build()
|
||||
|
||||
expectedErr := merr.WrapErrCollectionNotFound("test_collection")
|
||||
mockey.Mock((*MetaCache).GetCollectionID).Return(int64(0), expectedErr).Build()
|
||||
|
||||
task := createTestUpdateTask()
|
||||
|
||||
err := task.PreExecute(context.Background())
|
||||
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateTask_PreExecute_PartitionKeyModeError(t *testing.T) {
|
||||
mockey.PatchConvey("TestUpdateTask_PreExecute_PartitionKeyModeError", t, func() {
|
||||
globalMetaCache = &MetaCache{}
|
||||
|
||||
mockey.Mock(GetReplicateID).Return("", nil).Build()
|
||||
mockey.Mock((*MetaCache).GetCollectionID).Return(int64(1001), nil).Build()
|
||||
mockey.Mock((*MetaCache).GetCollectionInfo).Return(&collectionInfo{
|
||||
updateTimestamp: 12345,
|
||||
}, nil).Build()
|
||||
mockey.Mock((*MetaCache).GetCollectionSchema).Return(createTestSchema(), nil).Build()
|
||||
|
||||
mockey.Mock(isPartitionKeyMode).Return(true, nil).Build()
|
||||
|
||||
task := createTestUpdateTask()
|
||||
task.req.PartitionName = "custom_partition" // This should cause error in partition key mode
|
||||
|
||||
err := task.PreExecute(context.Background())
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not support manually specifying the partition names if partition key mode is used")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateTask_PreExecute_InvalidNumRows(t *testing.T) {
|
||||
mockey.PatchConvey("TestUpdateTask_PreExecute_InvalidNumRows", t, func() {
|
||||
globalMetaCache = &MetaCache{}
|
||||
|
||||
mockey.Mock(GetReplicateID).Return("", nil).Build()
|
||||
mockey.Mock((*MetaCache).GetCollectionID).Return(int64(1001), nil).Build()
|
||||
mockey.Mock((*MetaCache).GetCollectionInfo).Return(&collectionInfo{
|
||||
updateTimestamp: 12345,
|
||||
}, nil).Build()
|
||||
mockey.Mock((*MetaCache).GetCollectionSchema).Return(createTestSchema(), nil).Build()
|
||||
mockey.Mock(isPartitionKeyMode).Return(false, nil).Build()
|
||||
mockey.Mock((*MetaCache).GetPartitionInfo).Return(&partitionInfo{
|
||||
name: "_default",
|
||||
}, nil).Build()
|
||||
|
||||
task := createTestUpdateTask()
|
||||
task.req.NumRows = 0 // Invalid num_rows
|
||||
|
||||
err := task.PreExecute(context.Background())
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid num_rows")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateTask_PreExecute_QueryPreExecuteError(t *testing.T) {
|
||||
mockey.PatchConvey("TestUpdateTask_PreExecute_QueryPreExecuteError", t, func() {
|
||||
globalMetaCache = &MetaCache{}
|
||||
|
||||
mockey.Mock(GetReplicateID).Return("", nil).Build()
|
||||
mockey.Mock((*MetaCache).GetCollectionID).Return(int64(1001), nil).Build()
|
||||
mockey.Mock((*MetaCache).GetCollectionInfo).Return(&collectionInfo{
|
||||
updateTimestamp: 12345,
|
||||
}, nil).Build()
|
||||
mockey.Mock((*MetaCache).GetCollectionSchema).Return(createTestSchema(), nil).Build()
|
||||
mockey.Mock(isPartitionKeyMode).Return(false, nil).Build()
|
||||
mockey.Mock((*MetaCache).GetPartitionInfo).Return(&partitionInfo{
|
||||
name: "_default",
|
||||
}, nil).Build()
|
||||
|
||||
expectedErr := merr.WrapErrParameterInvalidMsg("query pre-execute failed")
|
||||
mockey.Mock((*upsertTask).queryPreExecute).Return(expectedErr).Build()
|
||||
|
||||
task := createTestUpdateTask()
|
||||
task.req.PartialUpdate = true
|
||||
|
||||
err := task.PreExecute(context.Background())
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "query pre-execute failed")
|
||||
})
|
||||
}
|
||||
|
||||
@ -1934,6 +1934,38 @@ func checkPrimaryFieldData(allFields []*schemapb.FieldSchema, schema *schemapb.C
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// check whether insertMsg has all fields in schema
|
||||
func LackOfFieldsDataBySchema(schema *schemapb.CollectionSchema, fieldsData []*schemapb.FieldData, skipPkFieldCheck bool, skipDynamicFieldCheck bool) error {
|
||||
log := log.With(zap.String("collection", schema.GetName()))
|
||||
|
||||
// find bm25 generated fields
|
||||
bm25Fields := typeutil.NewSet[string](GetFunctionOutputFields(schema)...)
|
||||
dataNameMap := make(map[string]*schemapb.FieldData)
|
||||
for _, data := range fieldsData {
|
||||
dataNameMap[data.GetFieldName()] = data
|
||||
}
|
||||
|
||||
for _, fieldSchema := range schema.Fields {
|
||||
if bm25Fields.Contain(fieldSchema.GetName()) {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := dataNameMap[fieldSchema.GetName()]; !ok {
|
||||
if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && skipPkFieldCheck) ||
|
||||
IsBM25FunctionOutputField(fieldSchema, schema) ||
|
||||
(skipDynamicFieldCheck && fieldSchema.GetIsDynamic()) {
|
||||
// autoGenField
|
||||
continue
|
||||
}
|
||||
|
||||
log.Info("no corresponding fieldData pass in", zap.String("fieldSchema", fieldSchema.GetName()))
|
||||
return merr.WrapErrParameterInvalidMsg("fieldSchema(%s) has no corresponding fieldData pass in", fieldSchema.GetName())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// for some varchar with analzyer
|
||||
// we need check char format before insert it to message queue
|
||||
// now only support utf-8
|
||||
@ -2490,6 +2522,14 @@ func IsBM25FunctionOutputField(field *schemapb.FieldSchema, collSchema *schemapb
|
||||
return false
|
||||
}
|
||||
|
||||
func GetFunctionOutputFields(collSchema *schemapb.CollectionSchema) []string {
|
||||
fields := make([]string, 0)
|
||||
for _, fSchema := range collSchema.Functions {
|
||||
fields = append(fields, fSchema.OutputFieldNames...)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
func getCollectionTTL(pairs []*commonpb.KeyValuePair) uint64 {
|
||||
properties := make(map[string]string)
|
||||
for _, pair := range pairs {
|
||||
|
||||
@ -216,7 +216,6 @@ replace (
|
||||
github.com/go-kit/kit => github.com/go-kit/kit v0.1.0
|
||||
github.com/golang-jwt/jwt => github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
|
||||
github.com/ianlancetaylor/cgosymbolizer => github.com/milvus-io/cgosymbolizer v0.0.0-20250318084424-114f4050c3a6
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.6.0-rc.1.0.20250716031043-88051c3893ce => /home/zhicheng/SecWorkspace/milvus-proto/go-api
|
||||
github.com/streamnative/pulsarctl => github.com/xiaofan-luan/pulsarctl v0.5.1
|
||||
github.com/tecbot/gorocksdb => github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b // indirect
|
||||
)
|
||||
|
||||
212
pkg/util/typeutil/ids_checker.go
Normal file
212
pkg/util/typeutil/ids_checker.go
Normal file
@ -0,0 +1,212 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package typeutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
// IDsChecker provides efficient lookup functionality for schema.IDs
|
||||
// It supports checking if an ID at a specific position in one IDs exists in another IDs
|
||||
type IDsChecker struct {
|
||||
intIDSet map[int64]struct{}
|
||||
strIDSet map[string]struct{}
|
||||
idType schemapb.DataType
|
||||
}
|
||||
|
||||
// NewIDsChecker creates a new IDsChecker from the given IDs
|
||||
// The checker will build an internal set for fast lookup
|
||||
func NewIDsChecker(ids *schemapb.IDs) (*IDsChecker, error) {
|
||||
if ids == nil || ids.GetIdField() == nil {
|
||||
return &IDsChecker{}, nil
|
||||
}
|
||||
|
||||
checker := &IDsChecker{}
|
||||
|
||||
switch ids.GetIdField().(type) {
|
||||
case *schemapb.IDs_IntId:
|
||||
checker.idType = schemapb.DataType_Int64
|
||||
checker.intIDSet = make(map[int64]struct{})
|
||||
data := ids.GetIntId().GetData()
|
||||
for _, id := range data {
|
||||
checker.intIDSet[id] = struct{}{}
|
||||
}
|
||||
case *schemapb.IDs_StrId:
|
||||
checker.idType = schemapb.DataType_VarChar
|
||||
checker.strIDSet = make(map[string]struct{})
|
||||
data := ids.GetStrId().GetData()
|
||||
for _, id := range data {
|
||||
checker.strIDSet[id] = struct{}{}
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported ID type in IDs")
|
||||
}
|
||||
|
||||
return checker, nil
|
||||
}
|
||||
|
||||
// Contains checks if the ID at the specified cursor position in idsA exists in this checker
|
||||
// Returns true if the ID exists, false otherwise
|
||||
// Returns error if cursor is out of bounds or type mismatch
|
||||
func (c *IDsChecker) Contains(idsA *schemapb.IDs, cursor int) (bool, error) {
|
||||
if idsA == nil || idsA.GetIdField() == nil {
|
||||
return false, fmt.Errorf("idsA is nil or empty")
|
||||
}
|
||||
|
||||
// Check if cursor is within bounds
|
||||
size := GetSizeOfIDs(idsA)
|
||||
if cursor < 0 || cursor >= size {
|
||||
return false, fmt.Errorf("cursor %d is out of bounds [0, %d)", cursor, size)
|
||||
}
|
||||
|
||||
// If checker is empty, return false for any query
|
||||
if c.IsEmpty() {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
switch idsA.GetIdField().(type) {
|
||||
case *schemapb.IDs_IntId:
|
||||
if c.idType != schemapb.DataType_Int64 {
|
||||
return false, fmt.Errorf("type mismatch: checker expects %v, got Int64", c.idType)
|
||||
}
|
||||
if c.intIDSet == nil {
|
||||
return false, nil
|
||||
}
|
||||
id := idsA.GetIntId().GetData()[cursor]
|
||||
_, exists := c.intIDSet[id]
|
||||
return exists, nil
|
||||
|
||||
case *schemapb.IDs_StrId:
|
||||
if c.idType != schemapb.DataType_VarChar {
|
||||
return false, fmt.Errorf("type mismatch: checker expects %v, got VarChar", c.idType)
|
||||
}
|
||||
if c.strIDSet == nil {
|
||||
return false, nil
|
||||
}
|
||||
id := idsA.GetStrId().GetData()[cursor]
|
||||
_, exists := c.strIDSet[id]
|
||||
return exists, nil
|
||||
|
||||
default:
|
||||
return false, fmt.Errorf("unsupported ID type in idsA")
|
||||
}
|
||||
}
|
||||
|
||||
// ContainsAny checks if any ID in idsA exists in this checker
|
||||
// Returns the indices of IDs that exist in the checker
|
||||
func (c *IDsChecker) ContainsAny(idsA *schemapb.IDs) ([]int, error) {
|
||||
if idsA == nil || idsA.GetIdField() == nil {
|
||||
return nil, fmt.Errorf("idsA is nil or empty")
|
||||
}
|
||||
|
||||
var result []int
|
||||
|
||||
// If checker is empty, return empty result for any query
|
||||
if c.IsEmpty() {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
switch idsA.GetIdField().(type) {
|
||||
case *schemapb.IDs_IntId:
|
||||
if c.idType != schemapb.DataType_Int64 {
|
||||
return nil, fmt.Errorf("type mismatch: checker expects %v, got Int64", c.idType)
|
||||
}
|
||||
if c.intIDSet == nil {
|
||||
return result, nil
|
||||
}
|
||||
data := idsA.GetIntId().GetData()
|
||||
for i, id := range data {
|
||||
if _, exists := c.intIDSet[id]; exists {
|
||||
result = append(result, i)
|
||||
}
|
||||
}
|
||||
|
||||
case *schemapb.IDs_StrId:
|
||||
if c.idType != schemapb.DataType_VarChar {
|
||||
return nil, fmt.Errorf("type mismatch: checker expects %v, got VarChar", c.idType)
|
||||
}
|
||||
if c.strIDSet == nil {
|
||||
return result, nil
|
||||
}
|
||||
data := idsA.GetStrId().GetData()
|
||||
for i, id := range data {
|
||||
if _, exists := c.strIDSet[id]; exists {
|
||||
result = append(result, i)
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported ID type in idsA")
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Size returns the number of unique IDs in this checker
|
||||
func (c *IDsChecker) Size() int {
|
||||
switch c.idType {
|
||||
case schemapb.DataType_Int64:
|
||||
if c.intIDSet == nil {
|
||||
return 0
|
||||
}
|
||||
return len(c.intIDSet)
|
||||
case schemapb.DataType_VarChar:
|
||||
if c.strIDSet == nil {
|
||||
return 0
|
||||
}
|
||||
return len(c.strIDSet)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// IsEmpty returns true if the checker contains no IDs
|
||||
func (c *IDsChecker) IsEmpty() bool {
|
||||
return c.Size() == 0
|
||||
}
|
||||
|
||||
// GetIDType returns the data type of IDs in this checker
|
||||
func (c *IDsChecker) GetIDType() schemapb.DataType {
|
||||
return c.idType
|
||||
}
|
||||
|
||||
// ContainsIDsAtCursors is a batch operation that checks multiple cursor positions at once
|
||||
// Returns a slice of booleans indicating whether each cursor position exists in the checker
|
||||
func (c *IDsChecker) ContainsIDsAtCursors(idsA *schemapb.IDs, cursors []int) ([]bool, error) {
|
||||
if idsA == nil || idsA.GetIdField() == nil {
|
||||
return nil, fmt.Errorf("idsA is nil or empty")
|
||||
}
|
||||
|
||||
size := GetSizeOfIDs(idsA)
|
||||
results := make([]bool, len(cursors))
|
||||
|
||||
for i, cursor := range cursors {
|
||||
if cursor < 0 || cursor >= size {
|
||||
return nil, fmt.Errorf("cursor %d is out of bounds [0, %d)", cursor, size)
|
||||
}
|
||||
|
||||
exists, err := c.Contains(idsA, cursor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results[i] = exists
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
372
pkg/util/typeutil/ids_checker_test.go
Normal file
372
pkg/util/typeutil/ids_checker_test.go
Normal file
@ -0,0 +1,372 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package typeutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
)
|
||||
|
||||
func TestNewIDsChecker(t *testing.T) {
|
||||
t.Run("nil IDs", func(t *testing.T) {
|
||||
checker, err := NewIDsChecker(nil)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, checker.IsEmpty())
|
||||
})
|
||||
|
||||
t.Run("empty IDs", func(t *testing.T) {
|
||||
ids := &schemapb.IDs{}
|
||||
checker, err := NewIDsChecker(ids)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, checker.IsEmpty())
|
||||
})
|
||||
|
||||
t.Run("int64 IDs", func(t *testing.T) {
|
||||
ids := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 3, 4, 5},
|
||||
},
|
||||
},
|
||||
}
|
||||
checker, err := NewIDsChecker(ids)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, checker.IsEmpty())
|
||||
assert.Equal(t, 5, checker.Size())
|
||||
assert.Equal(t, schemapb.DataType_Int64, checker.GetIDType())
|
||||
})
|
||||
|
||||
t.Run("string IDs", func(t *testing.T) {
|
||||
ids := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: []string{"a", "b", "c", "d"},
|
||||
},
|
||||
},
|
||||
}
|
||||
checker, err := NewIDsChecker(ids)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, checker.IsEmpty())
|
||||
assert.Equal(t, 4, checker.Size())
|
||||
assert.Equal(t, schemapb.DataType_VarChar, checker.GetIDType())
|
||||
})
|
||||
|
||||
t.Run("duplicate IDs", func(t *testing.T) {
|
||||
ids := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 2, 3, 3, 3},
|
||||
},
|
||||
},
|
||||
}
|
||||
checker, err := NewIDsChecker(ids)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, checker.Size()) // Only unique IDs are counted
|
||||
})
|
||||
}
|
||||
|
||||
func TestIDsChecker_Contains(t *testing.T) {
|
||||
// Create checker with int64 IDs
|
||||
checkerIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{10, 20, 30, 40, 50},
|
||||
},
|
||||
},
|
||||
}
|
||||
checker, err := NewIDsChecker(checkerIDs)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("valid int64 lookups", func(t *testing.T) {
|
||||
queryIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{10, 15, 20, 25, 30},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Test each position
|
||||
exists, err := checker.Contains(queryIDs, 0) // 10 - should exist
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
exists, err = checker.Contains(queryIDs, 1) // 15 - should not exist
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
exists, err = checker.Contains(queryIDs, 2) // 20 - should exist
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
exists, err = checker.Contains(queryIDs, 3) // 25 - should not exist
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
exists, err = checker.Contains(queryIDs, 4) // 30 - should exist
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("out of bounds", func(t *testing.T) {
|
||||
queryIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{10, 20},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := checker.Contains(queryIDs, -1)
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = checker.Contains(queryIDs, 2)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("type mismatch", func(t *testing.T) {
|
||||
queryIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: []string{"a", "b"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := checker.Contains(queryIDs, 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "type mismatch")
|
||||
})
|
||||
|
||||
t.Run("nil query IDs", func(t *testing.T) {
|
||||
_, err := checker.Contains(nil, 0)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIDsChecker_ContainsString(t *testing.T) {
|
||||
// Create checker with string IDs
|
||||
checkerIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: []string{"apple", "banana", "cherry", "date"},
|
||||
},
|
||||
},
|
||||
}
|
||||
checker, err := NewIDsChecker(checkerIDs)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("valid string lookups", func(t *testing.T) {
|
||||
queryIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_StrId{
|
||||
StrId: &schemapb.StringArray{
|
||||
Data: []string{"apple", "grape", "banana", "kiwi"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
exists, err := checker.Contains(queryIDs, 0) // "apple" - should exist
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
exists, err = checker.Contains(queryIDs, 1) // "grape" - should not exist
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
exists, err = checker.Contains(queryIDs, 2) // "banana" - should exist
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
exists, err = checker.Contains(queryIDs, 3) // "kiwi" - should not exist
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIDsChecker_ContainsAny(t *testing.T) {
|
||||
checkerIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{10, 20, 30, 40, 50},
|
||||
},
|
||||
},
|
||||
}
|
||||
checker, err := NewIDsChecker(checkerIDs)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("some matches", func(t *testing.T) {
|
||||
queryIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{5, 10, 15, 20, 25, 30},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
indices, err := checker.ContainsAny(queryIDs)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []int{1, 3, 5}, indices) // positions of 10, 20, 30
|
||||
})
|
||||
|
||||
t.Run("no matches", func(t *testing.T) {
|
||||
queryIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 3, 4},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
indices, err := checker.ContainsAny(queryIDs)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, indices)
|
||||
})
|
||||
|
||||
t.Run("all matches", func(t *testing.T) {
|
||||
queryIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{10, 20, 30},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
indices, err := checker.ContainsAny(queryIDs)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []int{0, 1, 2}, indices)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIDsChecker_ContainsIDsAtCursors(t *testing.T) {
|
||||
checkerIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{10, 20, 30, 40, 50},
|
||||
},
|
||||
},
|
||||
}
|
||||
checker, err := NewIDsChecker(checkerIDs)
|
||||
require.NoError(t, err)
|
||||
|
||||
queryIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{10, 15, 20, 25, 30, 35},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("batch check", func(t *testing.T) {
|
||||
cursors := []int{0, 1, 2, 4}
|
||||
results, err := checker.ContainsIDsAtCursors(queryIDs, cursors)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []bool{true, false, true, true}, results)
|
||||
})
|
||||
|
||||
t.Run("out of bounds in batch", func(t *testing.T) {
|
||||
cursors := []int{0, 1, 10} // 10 is out of bounds
|
||||
_, err := checker.ContainsIDsAtCursors(queryIDs, cursors)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIDsChecker_EmptyChecker(t *testing.T) {
|
||||
checker, err := NewIDsChecker(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
queryIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 3},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Empty checker should return false for any query
|
||||
exists, err := checker.Contains(queryIDs, 0)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
indices, err := checker.ContainsAny(queryIDs)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, indices)
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkIDsChecker_Contains(b *testing.B) {
|
||||
// Create a large checker
|
||||
data := make([]int64, 10000)
|
||||
for i := range data {
|
||||
data[i] = int64(i * 2) // Even numbers
|
||||
}
|
||||
checkerIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{Data: data},
|
||||
},
|
||||
}
|
||||
checker, _ := NewIDsChecker(checkerIDs)
|
||||
|
||||
// Create query IDs
|
||||
queryData := make([]int64, 1000)
|
||||
for i := range queryData {
|
||||
queryData[i] = int64(i) // Mix of even and odd numbers
|
||||
}
|
||||
queryIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{Data: queryData},
|
||||
},
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
checker.Contains(queryIDs, i%1000)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIDsChecker_ContainsAny(b *testing.B) {
|
||||
// Create a large checker
|
||||
data := make([]int64, 10000)
|
||||
for i := range data {
|
||||
data[i] = int64(i * 2) // Even numbers
|
||||
}
|
||||
checkerIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{Data: data},
|
||||
},
|
||||
}
|
||||
checker, _ := NewIDsChecker(checkerIDs)
|
||||
|
||||
// Create query IDs
|
||||
queryData := make([]int64, 1000)
|
||||
for i := range queryData {
|
||||
queryData[i] = int64(i) // Mix of even and odd numbers
|
||||
}
|
||||
queryIDs := &schemapb.IDs{
|
||||
IdField: &schemapb.IDs_IntId{
|
||||
IntId: &schemapb.LongArray{Data: queryData},
|
||||
},
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
checker.ContainsAny(queryIDs)
|
||||
}
|
||||
}
|
||||
@ -752,31 +752,39 @@ func PrepareResultFieldData(sample []*schemapb.FieldData, topK int64) []*schemap
|
||||
|
||||
// AppendFieldData appends fields data of specified index from src to dst
|
||||
func AppendFieldData(dst, src []*schemapb.FieldData, idx int64) (appendSize int64) {
|
||||
dstMap := make(map[int64]*schemapb.FieldData)
|
||||
for _, fieldData := range dst {
|
||||
if fieldData != nil {
|
||||
dstMap[fieldData.FieldId] = fieldData
|
||||
}
|
||||
}
|
||||
for i, fieldData := range src {
|
||||
if dst[i] == nil {
|
||||
dst[i] = &schemapb.FieldData{
|
||||
dstFieldData, ok := dstMap[fieldData.FieldId]
|
||||
if !ok {
|
||||
dstFieldData = &schemapb.FieldData{
|
||||
Type: fieldData.Type,
|
||||
FieldName: fieldData.FieldName,
|
||||
FieldId: fieldData.FieldId,
|
||||
IsDynamic: fieldData.IsDynamic,
|
||||
}
|
||||
dst[i] = dstFieldData
|
||||
}
|
||||
// assign null data
|
||||
if len(fieldData.GetValidData()) != 0 {
|
||||
if dst[i].ValidData == nil {
|
||||
dst[i].ValidData = make([]bool, 0)
|
||||
if dstFieldData.ValidData == nil {
|
||||
dstFieldData.ValidData = make([]bool, 0)
|
||||
}
|
||||
valid := fieldData.ValidData[idx]
|
||||
dst[i].ValidData = append(dst[i].ValidData, valid)
|
||||
dstFieldData.ValidData = append(dstFieldData.ValidData, valid)
|
||||
}
|
||||
switch fieldType := fieldData.Field.(type) {
|
||||
case *schemapb.FieldData_Scalars:
|
||||
if dst[i] == nil || dst[i].GetScalars() == nil {
|
||||
dst[i].Field = &schemapb.FieldData_Scalars{
|
||||
if dstFieldData.GetScalars() == nil {
|
||||
dstFieldData.Field = &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{},
|
||||
}
|
||||
}
|
||||
dstScalar := dst[i].GetScalars()
|
||||
dstScalar := dstFieldData.GetScalars()
|
||||
switch srcScalar := fieldType.Scalars.Data.(type) {
|
||||
case *schemapb.ScalarField_BoolData:
|
||||
if dstScalar.GetBoolData() == nil {
|
||||
@ -880,19 +888,14 @@ func AppendFieldData(dst, src []*schemapb.FieldData, idx int64) (appendSize int6
|
||||
}
|
||||
case *schemapb.FieldData_Vectors:
|
||||
dim := fieldType.Vectors.Dim
|
||||
if dst[i] == nil || dst[i].GetVectors() == nil {
|
||||
dst[i] = &schemapb.FieldData{
|
||||
Type: fieldData.Type,
|
||||
FieldName: fieldData.FieldName,
|
||||
FieldId: fieldData.FieldId,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: dim,
|
||||
},
|
||||
if dstFieldData.GetVectors() == nil {
|
||||
dstFieldData.Field = &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: dim,
|
||||
},
|
||||
}
|
||||
}
|
||||
dstVector := dst[i].GetVectors()
|
||||
dstVector := dstFieldData.GetVectors()
|
||||
switch srcVector := fieldType.Vectors.Data.(type) {
|
||||
case *schemapb.VectorField_BinaryVector:
|
||||
if dstVector.GetBinaryVector() == nil {
|
||||
@ -1052,6 +1055,192 @@ func DeleteFieldData(dst []*schemapb.FieldData) {
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateFieldData(base, update []*schemapb.FieldData, idx int64) error {
|
||||
// Create a map for quick lookup of update fields by field ID
|
||||
updateFieldMap := make(map[string]*schemapb.FieldData)
|
||||
for _, fieldData := range update {
|
||||
updateFieldMap[fieldData.FieldName] = fieldData
|
||||
}
|
||||
// Iterate through base fields and update if corresponding field exists in update
|
||||
for _, baseFieldData := range base {
|
||||
updateFieldData, exists := updateFieldMap[baseFieldData.FieldName]
|
||||
if !exists {
|
||||
// No update for this field, keep original value
|
||||
continue
|
||||
}
|
||||
|
||||
// Update ValidData if present
|
||||
if len(updateFieldData.GetValidData()) != 0 {
|
||||
if len(baseFieldData.GetValidData()) != 0 {
|
||||
baseFieldData.ValidData[idx] = updateFieldData.ValidData[idx]
|
||||
}
|
||||
|
||||
// update field data to null,only modify valid data
|
||||
if !updateFieldData.ValidData[idx] {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Update field data based on type
|
||||
switch baseFieldType := baseFieldData.Field.(type) {
|
||||
case *schemapb.FieldData_Scalars:
|
||||
updateFieldType := updateFieldData.Field.(*schemapb.FieldData_Scalars)
|
||||
baseScalar := baseFieldType.Scalars
|
||||
updateScalar := updateFieldType.Scalars
|
||||
|
||||
switch baseScalar.Data.(type) {
|
||||
case *schemapb.ScalarField_BoolData:
|
||||
updateData := updateScalar.GetBoolData()
|
||||
if updateData != nil && int(idx) < len(updateData.Data) {
|
||||
baseScalar.GetBoolData().Data[idx] = updateData.Data[idx]
|
||||
}
|
||||
case *schemapb.ScalarField_IntData:
|
||||
updateData := updateScalar.GetIntData()
|
||||
if updateData != nil && int(idx) < len(updateData.Data) {
|
||||
baseScalar.GetIntData().Data[idx] = updateData.Data[idx]
|
||||
}
|
||||
case *schemapb.ScalarField_LongData:
|
||||
updateData := updateScalar.GetLongData()
|
||||
if updateData != nil && int(idx) < len(updateData.Data) {
|
||||
baseScalar.GetLongData().Data[idx] = updateData.Data[idx]
|
||||
}
|
||||
case *schemapb.ScalarField_FloatData:
|
||||
updateData := updateScalar.GetFloatData()
|
||||
if updateData != nil && int(idx) < len(updateData.Data) {
|
||||
baseScalar.GetFloatData().Data[idx] = updateData.Data[idx]
|
||||
}
|
||||
case *schemapb.ScalarField_DoubleData:
|
||||
updateData := updateScalar.GetDoubleData()
|
||||
if updateData != nil && int(idx) < len(updateData.Data) {
|
||||
baseScalar.GetDoubleData().Data[idx] = updateData.Data[idx]
|
||||
}
|
||||
case *schemapb.ScalarField_StringData:
|
||||
updateData := updateScalar.GetStringData()
|
||||
if updateData != nil && int(idx) < len(updateData.Data) {
|
||||
baseScalar.GetStringData().Data[idx] = updateData.Data[idx]
|
||||
}
|
||||
case *schemapb.ScalarField_ArrayData:
|
||||
updateData := updateScalar.GetArrayData()
|
||||
if updateData != nil && int(idx) < len(updateData.Data) {
|
||||
baseScalar.GetArrayData().Data[idx] = updateData.Data[idx]
|
||||
}
|
||||
case *schemapb.ScalarField_JsonData:
|
||||
updateData := updateScalar.GetJsonData()
|
||||
baseData := baseScalar.GetJsonData()
|
||||
if updateData != nil && int(idx) < len(updateData.Data) {
|
||||
if baseFieldData.GetIsDynamic() {
|
||||
// dynamic field is a json with only 1 level nested struct,
|
||||
// so we need to unmarshal and iterate updateData's key value, and update the baseData's key value
|
||||
var baseMap map[string]interface{}
|
||||
var updateMap map[string]interface{}
|
||||
// unmarshal base and update
|
||||
if err := json.Unmarshal(baseData.Data[idx], &baseMap); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal base json: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal(updateData.Data[idx], &updateMap); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal update json: %v", err)
|
||||
}
|
||||
// merge
|
||||
for k, v := range updateMap {
|
||||
baseMap[k] = v
|
||||
}
|
||||
// marshal back
|
||||
newJSON, err := json.Marshal(baseMap)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal merged json: %v", err)
|
||||
}
|
||||
baseScalar.GetJsonData().Data[idx] = newJSON
|
||||
} else {
|
||||
baseScalar.GetJsonData().Data[idx] = updateData.Data[idx]
|
||||
}
|
||||
}
|
||||
default:
|
||||
log.Error("Not supported scalar field type", zap.String("field type", baseFieldData.Type.String()))
|
||||
return fmt.Errorf("unsupported scalar field type: %s", baseFieldData.Type.String())
|
||||
}
|
||||
|
||||
case *schemapb.FieldData_Vectors:
|
||||
updateFieldType := updateFieldData.Field.(*schemapb.FieldData_Vectors)
|
||||
baseVector := baseFieldType.Vectors
|
||||
updateVector := updateFieldType.Vectors
|
||||
dim := baseVector.Dim
|
||||
|
||||
switch baseVector.Data.(type) {
|
||||
case *schemapb.VectorField_BinaryVector:
|
||||
updateData := updateVector.GetBinaryVector()
|
||||
if updateData != nil {
|
||||
baseData := baseVector.GetBinaryVector()
|
||||
startIdx := idx * (dim / 8)
|
||||
endIdx := (idx + 1) * (dim / 8)
|
||||
if int(endIdx) <= len(updateData) && int(endIdx) <= len(baseData) {
|
||||
copy(baseData[startIdx:endIdx], updateData[startIdx:endIdx])
|
||||
}
|
||||
}
|
||||
case *schemapb.VectorField_FloatVector:
|
||||
updateData := updateVector.GetFloatVector()
|
||||
if updateData != nil {
|
||||
baseData := baseVector.GetFloatVector()
|
||||
startIdx := idx * dim
|
||||
endIdx := (idx + 1) * dim
|
||||
if int(endIdx) <= len(updateData.Data) && int(endIdx) <= len(baseData.Data) {
|
||||
copy(baseData.Data[startIdx:endIdx], updateData.Data[startIdx:endIdx])
|
||||
}
|
||||
}
|
||||
case *schemapb.VectorField_Float16Vector:
|
||||
updateData := updateVector.GetFloat16Vector()
|
||||
if updateData != nil {
|
||||
baseData := baseVector.GetFloat16Vector()
|
||||
startIdx := idx * (dim * 2)
|
||||
endIdx := (idx + 1) * (dim * 2)
|
||||
if int(endIdx) <= len(updateData) && int(endIdx) <= len(baseData) {
|
||||
copy(baseData[startIdx:endIdx], updateData[startIdx:endIdx])
|
||||
}
|
||||
}
|
||||
case *schemapb.VectorField_Bfloat16Vector:
|
||||
updateData := updateVector.GetBfloat16Vector()
|
||||
if updateData != nil {
|
||||
baseData := baseVector.GetBfloat16Vector()
|
||||
startIdx := idx * (dim * 2)
|
||||
endIdx := (idx + 1) * (dim * 2)
|
||||
if int(endIdx) <= len(updateData) && int(endIdx) <= len(baseData) {
|
||||
copy(baseData[startIdx:endIdx], updateData[startIdx:endIdx])
|
||||
}
|
||||
}
|
||||
case *schemapb.VectorField_SparseFloatVector:
|
||||
updateData := updateVector.GetSparseFloatVector()
|
||||
if updateData != nil && int(idx) < len(updateData.Contents) {
|
||||
baseData := baseVector.GetSparseFloatVector()
|
||||
if int(idx) < len(baseData.Contents) {
|
||||
baseData.Contents[idx] = updateData.Contents[idx]
|
||||
// Update dimension if necessary
|
||||
if updateData.Dim > baseData.Dim {
|
||||
baseData.Dim = updateData.Dim
|
||||
}
|
||||
}
|
||||
}
|
||||
case *schemapb.VectorField_Int8Vector:
|
||||
updateData := updateVector.GetInt8Vector()
|
||||
if updateData != nil {
|
||||
baseData := baseVector.GetInt8Vector()
|
||||
startIdx := idx * dim
|
||||
endIdx := (idx + 1) * dim
|
||||
if int(endIdx) <= len(updateData) && int(endIdx) <= len(baseData) {
|
||||
copy(baseData[startIdx:endIdx], updateData[startIdx:endIdx])
|
||||
}
|
||||
}
|
||||
default:
|
||||
log.Error("Not supported vector field type", zap.String("field type", baseFieldData.Type.String()))
|
||||
return fmt.Errorf("unsupported vector field type: %s", baseFieldData.Type.String())
|
||||
}
|
||||
default:
|
||||
log.Error("Not supported field type", zap.String("field type", baseFieldData.Type.String()))
|
||||
return fmt.Errorf("unsupported field type: %s", baseFieldData.Type.String())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MergeFieldData appends fields data to dst
|
||||
func MergeFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData) error {
|
||||
fieldID2Data := make(map[int64]*schemapb.FieldData)
|
||||
|
||||
@ -18,6 +18,7 @@ package typeutil
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
@ -2980,3 +2981,443 @@ func TestGetDataIterator(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateFieldData(t *testing.T) {
|
||||
const (
|
||||
Dim = 8
|
||||
BoolFieldName = "BoolField"
|
||||
Int64FieldName = "Int64Field"
|
||||
FloatFieldName = "FloatField"
|
||||
StringFieldName = "StringField"
|
||||
FloatVectorFieldName = "FloatVectorField"
|
||||
BoolFieldID = common.StartOfUserFieldID + 1
|
||||
Int64FieldID = common.StartOfUserFieldID + 2
|
||||
FloatFieldID = common.StartOfUserFieldID + 3
|
||||
StringFieldID = common.StartOfUserFieldID + 4
|
||||
FloatVectorFieldID = common.StartOfUserFieldID + 5
|
||||
)
|
||||
|
||||
t.Run("update scalar fields", func(t *testing.T) {
|
||||
// Create base data
|
||||
baseData := []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Bool,
|
||||
FieldName: BoolFieldName,
|
||||
FieldId: BoolFieldID,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_BoolData{
|
||||
BoolData: &schemapb.BoolArray{
|
||||
Data: []bool{true, false, true, false},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: Int64FieldName,
|
||||
FieldId: Int64FieldID,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 3, 4},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
FieldName: StringFieldName,
|
||||
FieldId: StringFieldID,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: []string{"a", "b", "c", "d"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create update data (only update some fields)
|
||||
updateData := []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Bool,
|
||||
FieldName: BoolFieldName,
|
||||
FieldId: BoolFieldID,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_BoolData{
|
||||
BoolData: &schemapb.BoolArray{
|
||||
Data: []bool{false, true, false, true}, // Updated values
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: schemapb.DataType_VarChar,
|
||||
FieldName: StringFieldName,
|
||||
FieldId: StringFieldID,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{
|
||||
Data: []string{"x", "y", "z", "w"}, // Updated values
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
// Note: Int64Field is not in update, so it should remain unchanged
|
||||
}
|
||||
|
||||
// Update index 1
|
||||
err := UpdateFieldData(baseData, updateData, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check results
|
||||
// Bool field should be updated at index 1
|
||||
assert.Equal(t, true, baseData[0].GetScalars().GetBoolData().Data[1])
|
||||
// Int64 field should remain unchanged at index 1
|
||||
assert.Equal(t, int64(2), baseData[1].GetScalars().GetLongData().Data[1])
|
||||
// String field should be updated at index 1
|
||||
assert.Equal(t, "y", baseData[2].GetScalars().GetStringData().Data[1])
|
||||
|
||||
// Other indices should remain unchanged
|
||||
assert.Equal(t, true, baseData[0].GetScalars().GetBoolData().Data[0])
|
||||
assert.Equal(t, int64(1), baseData[1].GetScalars().GetLongData().Data[0])
|
||||
assert.Equal(t, "a", baseData[2].GetScalars().GetStringData().Data[0])
|
||||
})
|
||||
|
||||
t.Run("update vector fields", func(t *testing.T) {
|
||||
// Create base vector data
|
||||
baseData := []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_FloatVector,
|
||||
FieldName: FloatVectorFieldName,
|
||||
FieldId: FloatVectorFieldID,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: Dim,
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: []float32{
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, // vector 0
|
||||
9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, // vector 1
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create update data
|
||||
updateData := []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_FloatVector,
|
||||
FieldName: FloatVectorFieldName,
|
||||
FieldId: FloatVectorFieldID,
|
||||
Field: &schemapb.FieldData_Vectors{
|
||||
Vectors: &schemapb.VectorField{
|
||||
Dim: Dim,
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{
|
||||
Data: []float32{
|
||||
100.0, 200.0, 300.0, 400.0, 500.0, 600.0, 700.0, 800.0, // vector 0
|
||||
900.0, 1000.0, 1100.0, 1200.0, 1300.0, 1400.0, 1500.0, 1600.0, // vector 1
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Update index 1 (second vector)
|
||||
err := UpdateFieldData(baseData, updateData, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check results
|
||||
vectorData := baseData[0].GetVectors().GetFloatVector().Data
|
||||
|
||||
// First vector should remain unchanged
|
||||
for i := 0; i < Dim; i++ {
|
||||
assert.Equal(t, float32(i+1), vectorData[i])
|
||||
}
|
||||
|
||||
// Second vector should be updated
|
||||
for i := 0; i < Dim; i++ {
|
||||
assert.Equal(t, float32(900+i*100), vectorData[Dim+i])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no update fields", func(t *testing.T) {
|
||||
baseData := []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: Int64FieldName,
|
||||
FieldId: Int64FieldID,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 3, 4},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Empty update data
|
||||
updateData := []*schemapb.FieldData{}
|
||||
|
||||
// Update should succeed but change nothing
|
||||
err := UpdateFieldData(baseData, updateData, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Data should remain unchanged
|
||||
assert.Equal(t, int64(2), baseData[0].GetScalars().GetLongData().Data[1])
|
||||
})
|
||||
|
||||
t.Run("update with ValidData", func(t *testing.T) {
|
||||
baseData := []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: Int64FieldName,
|
||||
FieldId: Int64FieldID,
|
||||
ValidData: []bool{true, true, false, true},
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 3, 4},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
updateData := []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: Int64FieldName,
|
||||
FieldId: Int64FieldID,
|
||||
ValidData: []bool{false, false, true, false},
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: []int64{10, 20, 30, 40},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Update index 1
|
||||
err := UpdateFieldData(baseData, updateData, 1)
|
||||
require.NoError(t, err)
|
||||
err = UpdateFieldData(baseData, updateData, 2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that ValidData was updated
|
||||
assert.Equal(t, false, baseData[0].ValidData[1])
|
||||
// Check that data was updated
|
||||
assert.Equal(t, int64(2), baseData[0].GetScalars().GetLongData().Data[1])
|
||||
assert.Equal(t, int64(30), baseData[0].GetScalars().GetLongData().Data[2])
|
||||
})
|
||||
|
||||
t.Run("update dynamic json field", func(t *testing.T) {
|
||||
// Create base data with dynamic JSON field
|
||||
baseData := []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_JSON,
|
||||
FieldName: "json_field",
|
||||
FieldId: 1,
|
||||
IsDynamic: true,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_JsonData{
|
||||
JsonData: &schemapb.JSONArray{
|
||||
Data: [][]byte{
|
||||
[]byte(`{"key1": "value1", "key2": 123}`),
|
||||
[]byte(`{"key3": true, "key4": 456.789}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create update data
|
||||
updateData := []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_JSON,
|
||||
FieldName: "json_field",
|
||||
FieldId: 1,
|
||||
IsDynamic: true,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_JsonData{
|
||||
JsonData: &schemapb.JSONArray{
|
||||
Data: [][]byte{
|
||||
[]byte(`{"key2": 999, "key5": "new_value"}`),
|
||||
[]byte(`{"key4": 111.222, "key6": false}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Test updating first row
|
||||
err := UpdateFieldData(baseData, updateData, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify first row was correctly merged
|
||||
firstRow := baseData[0].GetScalars().GetJsonData().Data[0]
|
||||
var result map[string]interface{}
|
||||
err = json.Unmarshal(firstRow, &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check merged values
|
||||
assert.Equal(t, "value1", result["key1"])
|
||||
assert.Equal(t, float64(999), result["key2"]) // Updated value
|
||||
assert.Equal(t, "new_value", result["key5"]) // New value
|
||||
|
||||
// Test updating second row
|
||||
err = UpdateFieldData(baseData, updateData, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify second row was correctly merged
|
||||
secondRow := baseData[0].GetScalars().GetJsonData().Data[1]
|
||||
err = json.Unmarshal(secondRow, &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check merged values
|
||||
assert.Equal(t, true, result["key3"])
|
||||
assert.Equal(t, float64(111.222), result["key4"]) // Updated value
|
||||
assert.Equal(t, false, result["key6"]) // New value
|
||||
})
|
||||
|
||||
t.Run("update non-dynamic json field", func(t *testing.T) {
|
||||
// Create base data with non-dynamic JSON field
|
||||
baseData := []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_JSON,
|
||||
FieldName: "json_field",
|
||||
FieldId: 1,
|
||||
IsDynamic: false,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_JsonData{
|
||||
JsonData: &schemapb.JSONArray{
|
||||
Data: [][]byte{
|
||||
[]byte(`{"key1": "value1"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create update data
|
||||
updateData := []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_JSON,
|
||||
FieldName: "json_field",
|
||||
FieldId: 1,
|
||||
IsDynamic: false,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_JsonData{
|
||||
JsonData: &schemapb.JSONArray{
|
||||
Data: [][]byte{
|
||||
[]byte(`{"key2": "value2"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Test updating
|
||||
err := UpdateFieldData(baseData, updateData, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// For non-dynamic fields, the update should completely replace the old value
|
||||
result := baseData[0].GetScalars().GetJsonData().Data[0]
|
||||
assert.Equal(t, []byte(`{"key2": "value2"}`), result)
|
||||
})
|
||||
|
||||
t.Run("invalid json data", func(t *testing.T) {
|
||||
// Create base data with invalid JSON
|
||||
baseData := []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_JSON,
|
||||
FieldName: "json_field",
|
||||
FieldId: 1,
|
||||
IsDynamic: true,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_JsonData{
|
||||
JsonData: &schemapb.JSONArray{
|
||||
Data: [][]byte{
|
||||
[]byte(`invalid json`),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
updateData := []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_JSON,
|
||||
FieldName: "json_field",
|
||||
FieldId: 1,
|
||||
IsDynamic: true,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_JsonData{
|
||||
JsonData: &schemapb.JSONArray{
|
||||
Data: [][]byte{
|
||||
[]byte(`{"key": "value"}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Test updating with invalid base JSON
|
||||
err := UpdateFieldData(baseData, updateData, 0)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to unmarshal base json")
|
||||
|
||||
// Create base data with valid JSON but invalid update
|
||||
baseData[0].GetScalars().GetJsonData().Data[0] = []byte(`{"key": "value"}`)
|
||||
updateData[0].GetScalars().GetJsonData().Data[0] = []byte(`invalid json`)
|
||||
|
||||
// Test updating with invalid update JSON
|
||||
err = UpdateFieldData(baseData, updateData, 0)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to unmarshal update json")
|
||||
})
|
||||
}
|
||||
|
||||
@ -289,6 +289,22 @@ func CheckSearchIteratorResult(ctx context.Context, t *testing.T, itr client.Sea
|
||||
}
|
||||
}
|
||||
|
||||
// check expected columns should be contains in actual columns
|
||||
func CheckPartialResult(t *testing.T, expColumns []column.Column, actualColumns []column.Column) {
|
||||
for _, expColumn := range expColumns {
|
||||
exist := false
|
||||
for _, actualColumn := range actualColumns {
|
||||
if expColumn.Name() == actualColumn.Name() && expColumn.Type() != entity.FieldTypeJSON {
|
||||
exist = true
|
||||
EqualColumn(t, expColumn, actualColumn)
|
||||
}
|
||||
}
|
||||
if !exist {
|
||||
log.Error("CheckQueryResult actualColumns no column", zap.String("name", expColumn.Name()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GenColumnDataOption -- create column data --
|
||||
type checkIndexOpt struct {
|
||||
state index.IndexState
|
||||
|
||||
@ -318,8 +318,6 @@ github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfr
|
||||
github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8=
|
||||
github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc=
|
||||
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.6.1-0.20250807040333-531631e7fce6 h1:qTBOTsZ3OwEXkrHRqPn562ddkDqeToIY6CstLIaVQYs=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.6.1-0.20250807040333-531631e7fce6/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.6.1-0.20250807065533-ebdc11f5df17 h1:zyrKuc0rwT5xWIFkZr/bFWXXYbYvSBMT3iFITnaR8IE=
|
||||
github.com/milvus-io/milvus-proto/go-api/v2 v2.6.1-0.20250807065533-ebdc11f5df17/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
|
||||
github.com/milvus-io/milvus/pkg/v2 v2.0.0-20250319085209-5a6b4e56d59e h1:VCr43pG4efacDbM4au70fh8/5hNTftoWzm1iEumvDWM=
|
||||
|
||||
@ -384,6 +384,10 @@ func GetAllFunctionsOutputFields(schema *entity.Schema) []string {
|
||||
return outputFields
|
||||
}
|
||||
|
||||
func GenColumnDataWithOption(fieldType entity.FieldType, option GenDataOption) column.Column {
|
||||
return GenColumnData(option.nb, fieldType, option)
|
||||
}
|
||||
|
||||
// GenColumnData GenColumnDataOption except dynamic column
|
||||
func GenColumnData(nb int, fieldType entity.FieldType, option GenDataOption) column.Column {
|
||||
dim := option.dim
|
||||
@ -655,6 +659,10 @@ func GenColumnDataWithFp32VecConversion(nb int, fieldType entity.FieldType, opti
|
||||
}
|
||||
}
|
||||
|
||||
func GenDynamicColumnDataWithOption(option GenDataOption) []column.Column {
|
||||
return GenDynamicColumnData(option.start, option.nb)
|
||||
}
|
||||
|
||||
func GenDynamicColumnData(start int, nb int) []column.Column {
|
||||
type ListStruct struct {
|
||||
List []int64 `json:"list" milvus:"name:list"`
|
||||
|
||||
206
tests/go_client/testcases/update_test.go
Normal file
206
tests/go_client/testcases/update_test.go
Normal file
@ -0,0 +1,206 @@
|
||||
package testcases
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/milvus-io/milvus/client/v2/column"
|
||||
"github.com/milvus-io/milvus/client/v2/entity"
|
||||
client "github.com/milvus-io/milvus/client/v2/milvusclient"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/tests/go_client/common"
|
||||
hp "github.com/milvus-io/milvus/tests/go_client/testcases/helper"
|
||||
)
|
||||
|
||||
func TestUpdatePartialFields(t *testing.T) {
|
||||
/*
|
||||
1. prepare create -> insert -> index -> load -> query
|
||||
2. partial update existing entities -> data updated -> query and verify
|
||||
*/
|
||||
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
|
||||
// connect
|
||||
mc := hp.CreateDefaultMilvusClient(ctx, t)
|
||||
|
||||
// create -> insert [0, 3000) -> flush -> index -> load
|
||||
// create -> insert -> flush -> index -> load
|
||||
prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.AllFields), hp.TNewFieldsOption(), hp.TNewSchemaOption().TWithEnableDynamicField(true))
|
||||
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption())
|
||||
prepare.FlushData(ctx, t, mc, schema.CollectionName)
|
||||
prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema))
|
||||
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
|
||||
|
||||
genPkWithSingleScalarField := func(option *hp.GenDataOption) ([]column.Column, []column.Column) {
|
||||
log.Info("genPkWithSingleScalarField")
|
||||
columns := make([]column.Column, 0, 2)
|
||||
columns = append(columns, hp.GenColumnDataWithOption(entity.FieldTypeInt64, *option))
|
||||
columns = append(columns, hp.GenColumnDataWithOption(entity.FieldTypeFloat, *option))
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
genPkWithSinglVectorField := func(option *hp.GenDataOption) ([]column.Column, []column.Column) {
|
||||
log.Info("genPkWithSinglVectorField")
|
||||
columns := make([]column.Column, 0, 2)
|
||||
columns = append(columns, hp.GenColumnDataWithOption(entity.FieldTypeInt64, *option))
|
||||
columns = append(columns, hp.GenColumnDataWithOption(entity.FieldTypeFloatVector, *option))
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
updateNb := 200
|
||||
for _, genColumnsFunc := range []func(*hp.GenDataOption) ([]column.Column, []column.Column){genPkWithSingleScalarField, genPkWithSinglVectorField} {
|
||||
// perform partial update operation for existing entities [0, 200) -> query and verify
|
||||
columns, dynamicColumns := genColumnsFunc(hp.TNewDataOption().TWithNb(updateNb).TWithStart(0))
|
||||
updateRes, err := mc.Upsert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName).WithColumns(columns...).WithColumns(dynamicColumns...).WithPartialUpdate(true))
|
||||
common.CheckErr(t, err, true)
|
||||
require.EqualValues(t, updateNb, updateRes.UpsertCount)
|
||||
|
||||
expr := fmt.Sprintf("%s < %d", common.DefaultInt64FieldName, updateNb)
|
||||
resSet, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(expr).WithOutputFields("*").WithConsistencyLevel(entity.ClStrong))
|
||||
common.CheckErr(t, err, true)
|
||||
common.CheckPartialResult(t, append(columns, hp.MergeColumnsToDynamic(updateNb, dynamicColumns, common.DefaultDynamicFieldName)), resSet.Fields)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPartialUpdateDynamicField(t *testing.T) {
|
||||
// enable dynamic field and perform partial update operations on dynamic columns
|
||||
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
|
||||
mc := hp.CreateDefaultMilvusClient(ctx, t)
|
||||
|
||||
// create -> insert [0, 3000) -> flush -> index -> load
|
||||
prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64Vec), hp.TNewFieldsOption(), hp.TNewSchemaOption().TWithEnableDynamicField(true))
|
||||
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption())
|
||||
prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema))
|
||||
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
|
||||
time.Sleep(time.Second * 4)
|
||||
// verify that dynamic field exists
|
||||
testNb := 10
|
||||
resSet, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(fmt.Sprintf("%s < %d", common.DefaultDynamicNumberField, testNb)).
|
||||
WithOutputFields(common.DefaultDynamicFieldName).WithConsistencyLevel(entity.ClStrong))
|
||||
common.CheckErr(t, err, true)
|
||||
require.Equal(t, testNb, resSet.GetColumn(common.DefaultDynamicFieldName).Len())
|
||||
|
||||
// 1. query and gets empty
|
||||
targetPk := int64(20000)
|
||||
resSet, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(fmt.Sprintf("%s == %d", common.DefaultInt64FieldName, targetPk)).
|
||||
WithOutputFields(common.DefaultDynamicFieldName).WithConsistencyLevel(entity.ClStrong))
|
||||
common.CheckErr(t, err, true)
|
||||
require.Equal(t, 0, resSet.GetColumn(common.DefaultDynamicFieldName).Len())
|
||||
|
||||
// 2. perform partial update operation for existing pk with dynamic column [a=1]
|
||||
dynamicColumnA := column.NewColumnInt32(common.DefaultDynamicNumberField, []int32{1})
|
||||
vecColumn := hp.GenColumnData(1, entity.FieldTypeFloatVector, *hp.TNewDataOption())
|
||||
pkColumnA := column.NewColumnInt64(common.DefaultInt64FieldName, []int64{targetPk})
|
||||
_, err = mc.Upsert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName).WithColumns(pkColumnA, dynamicColumnA, vecColumn).WithPartialUpdate(true))
|
||||
common.CheckErr(t, err, true)
|
||||
time.Sleep(time.Second * 4)
|
||||
resSet, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(fmt.Sprintf("%s == %d", common.DefaultInt64FieldName, targetPk)).
|
||||
WithOutputFields(common.DefaultDynamicFieldName).WithConsistencyLevel(entity.ClStrong))
|
||||
common.CheckErr(t, err, true)
|
||||
require.Equal(t, 1, resSet.GetColumn(common.DefaultDynamicFieldName).Len())
|
||||
common.EqualColumn(t, hp.MergeColumnsToDynamic(1, []column.Column{dynamicColumnA}, common.DefaultDynamicFieldName), resSet.GetColumn(common.DefaultDynamicFieldName))
|
||||
|
||||
// 3. perform partial update operation for existing pk with dynamic column [b=true]
|
||||
dynamicColumnB := column.NewColumnBool(common.DefaultDynamicBoolField, []bool{true})
|
||||
_, err = mc.Upsert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName).WithColumns(pkColumnA, dynamicColumnB).WithPartialUpdate(true))
|
||||
common.CheckErr(t, err, true)
|
||||
time.Sleep(time.Second * 4)
|
||||
resSet, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(fmt.Sprintf("%s == %d", common.DefaultInt64FieldName, targetPk)).
|
||||
WithOutputFields(common.DefaultDynamicFieldName).WithConsistencyLevel(entity.ClStrong))
|
||||
common.CheckErr(t, err, true)
|
||||
require.Equal(t, 1, resSet.GetColumn(common.DefaultDynamicFieldName).Len())
|
||||
common.EqualColumn(t, hp.MergeColumnsToDynamic(1, []column.Column{dynamicColumnA, dynamicColumnB}, common.DefaultDynamicFieldName), resSet.GetColumn(common.DefaultDynamicFieldName))
|
||||
|
||||
// 4. perform partial update operation for existing pk with dynamic column [a=2, b=false]
|
||||
dynamicColumnA = column.NewColumnInt32(common.DefaultDynamicNumberField, []int32{2})
|
||||
_, err = mc.Upsert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName).WithColumns(pkColumnA, dynamicColumnA).WithPartialUpdate(true))
|
||||
common.CheckErr(t, err, true)
|
||||
time.Sleep(time.Second * 4)
|
||||
resSet, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(fmt.Sprintf("%s == %d", common.DefaultInt64FieldName, targetPk)).
|
||||
WithOutputFields(common.DefaultDynamicFieldName).WithConsistencyLevel(entity.ClStrong))
|
||||
common.CheckErr(t, err, true)
|
||||
require.Equal(t, 1, resSet.GetColumn(common.DefaultDynamicFieldName).Len())
|
||||
common.EqualColumn(t, hp.MergeColumnsToDynamic(1, []column.Column{dynamicColumnA, dynamicColumnB}, common.DefaultDynamicFieldName), resSet.GetColumn(common.DefaultDynamicFieldName))
|
||||
}
|
||||
|
||||
func TestUpdateNullableFieldBehavior(t *testing.T) {
|
||||
/*
|
||||
Test nullable field behavior for Update operation:
|
||||
1. Insert data with nullable field having a value
|
||||
2. Update the same entity without providing the nullable field
|
||||
3. Verify that the nullable field retains its original value (not set to null)
|
||||
*/
|
||||
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
|
||||
mc := hp.CreateDefaultMilvusClient(ctx, t)
|
||||
|
||||
// Create collection with nullable field using custom schema
|
||||
collName := common.GenRandomString("update_nullable", 6)
|
||||
|
||||
// Create fields including nullable field
|
||||
pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
|
||||
vecField := entity.NewField().WithName(common.DefaultFloatVecFieldName).WithDataType(entity.FieldTypeFloatVector).WithDim(common.DefaultDim)
|
||||
nullableField := entity.NewField().WithName("nullable_varchar").WithDataType(entity.FieldTypeVarChar).WithMaxLength(100).WithNullable(true)
|
||||
|
||||
fields := []*entity.Field{pkField, vecField, nullableField}
|
||||
schema := hp.GenSchema(hp.TNewSchemaOption().TWithName(collName).TWithDescription("test nullable field behavior for update").TWithFields(fields))
|
||||
|
||||
// Create collection using schema
|
||||
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema))
|
||||
common.CheckErr(t, err, true)
|
||||
|
||||
// Cleanup
|
||||
t.Cleanup(func() {
|
||||
ctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), time.Second*10)
|
||||
defer cancel()
|
||||
err := mc.DropCollection(ctx, client.NewDropCollectionOption(collName))
|
||||
common.CheckErr(t, err, true)
|
||||
})
|
||||
|
||||
// Insert initial data with nullable field having a value
|
||||
pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, []int64{1, 2, 3})
|
||||
vecColumn := hp.GenColumnData(3, entity.FieldTypeFloatVector, *hp.TNewDataOption())
|
||||
nullableColumn := column.NewColumnVarChar("nullable_varchar", []string{"original_1", "original_2", "original_3"})
|
||||
|
||||
_, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(collName).WithColumns(pkColumn, vecColumn, nullableColumn))
|
||||
common.CheckErr(t, err, true)
|
||||
|
||||
// Use prepare pattern for remaining operations
|
||||
prepare := &hp.CollectionPrepare{}
|
||||
|
||||
// Flush data
|
||||
prepare.FlushData(ctx, t, mc, collName)
|
||||
|
||||
// Create index for vector field
|
||||
indexParams := hp.TNewIndexParams(schema)
|
||||
prepare.CreateIndex(ctx, t, mc, indexParams)
|
||||
|
||||
// Load collection
|
||||
loadParams := hp.NewLoadParams(collName)
|
||||
prepare.Load(ctx, t, mc, loadParams)
|
||||
|
||||
// Wait for loading to complete
|
||||
time.Sleep(time.Second * 5)
|
||||
|
||||
// Update entities without providing nullable field (should retain original values)
|
||||
updatePkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, []int64{1, 2})
|
||||
updateVecColumn := hp.GenColumnData(2, entity.FieldTypeFloatVector, *hp.TNewDataOption().TWithStart(100))
|
||||
|
||||
updateRes, err := mc.Upsert(ctx, client.NewColumnBasedInsertOption(collName).WithColumns(updatePkColumn, updateVecColumn).WithPartialUpdate(true))
|
||||
common.CheckErr(t, err, true)
|
||||
require.EqualValues(t, 2, updateRes.UpsertCount)
|
||||
|
||||
// Wait for consistency
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
// Query to verify nullable field retains original values
|
||||
resSet, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter(fmt.Sprintf("%s in [1, 2]", common.DefaultInt64FieldName)).WithOutputFields("*").WithConsistencyLevel(entity.ClStrong))
|
||||
common.CheckErr(t, err, true)
|
||||
|
||||
// Verify results
|
||||
require.Equal(t, 2, resSet.GetColumn("nullable_varchar").Len())
|
||||
nullableResults := resSet.GetColumn("nullable_varchar").(*column.ColumnVarChar).Data()
|
||||
require.Equal(t, "original_1", nullableResults[0])
|
||||
require.Equal(t, "original_2", nullableResults[1])
|
||||
}
|
||||
@ -1,6 +1,7 @@
|
||||
package testcases
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
@ -653,3 +654,91 @@ func TestUpsertWithoutLoading(t *testing.T) {
|
||||
// TestUpsertPartitionKeyCollection is a placeholder, skipped until test-helper
// support for generating a partition-key field is available.
func TestUpsertPartitionKeyCollection(t *testing.T) {
	t.Skip("waiting gen partition key field")
}
|
||||
|
||||
func TestUpsertNullableFieldBehavior(t *testing.T) {
|
||||
/*
|
||||
Test nullable field behavior for Upsert operation:
|
||||
1. Insert data with nullable field having a value
|
||||
2. Upsert the same entity without providing the nullable field
|
||||
3. Verify that the nullable field is set to null (upsert replaces all fields)
|
||||
*/
|
||||
ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
|
||||
mc := hp.CreateDefaultMilvusClient(ctx, t)
|
||||
|
||||
// Create collection with nullable field using custom schema
|
||||
collName := common.GenRandomString("upsert_nullable", 6)
|
||||
|
||||
// Create fields including nullable field
|
||||
pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
|
||||
vecField := entity.NewField().WithName(common.DefaultFloatVecFieldName).WithDataType(entity.FieldTypeFloatVector).WithDim(common.DefaultDim)
|
||||
nullableField := entity.NewField().WithName("nullable_varchar").WithDataType(entity.FieldTypeVarChar).WithMaxLength(100).WithNullable(true)
|
||||
|
||||
fields := []*entity.Field{pkField, vecField, nullableField}
|
||||
schema := hp.GenSchema(hp.TNewSchemaOption().TWithName(collName).TWithDescription("test nullable field behavior for upsert").TWithFields(fields))
|
||||
|
||||
// Create collection using schema
|
||||
err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema))
|
||||
common.CheckErr(t, err, true)
|
||||
|
||||
// Cleanup
|
||||
t.Cleanup(func() {
|
||||
ctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), time.Second*10)
|
||||
defer cancel()
|
||||
err := mc.DropCollection(ctx, client.NewDropCollectionOption(collName))
|
||||
common.CheckErr(t, err, true)
|
||||
})
|
||||
|
||||
// Insert initial data with nullable field having a value
|
||||
pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, []int64{1, 2, 3})
|
||||
vecColumn := hp.GenColumnData(3, entity.FieldTypeFloatVector, *hp.TNewDataOption())
|
||||
nullableColumn := column.NewColumnVarChar("nullable_varchar", []string{"original_1", "original_2", "original_3"})
|
||||
|
||||
_, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(collName).WithColumns(pkColumn, vecColumn, nullableColumn))
|
||||
common.CheckErr(t, err, true)
|
||||
|
||||
// Use prepare pattern for remaining operations
|
||||
prepare := &hp.CollectionPrepare{}
|
||||
|
||||
// Flush data
|
||||
prepare.FlushData(ctx, t, mc, collName)
|
||||
|
||||
// Create index for vector field
|
||||
indexParams := hp.TNewIndexParams(schema)
|
||||
prepare.CreateIndex(ctx, t, mc, indexParams)
|
||||
|
||||
// Load collection
|
||||
loadParams := hp.NewLoadParams(collName)
|
||||
prepare.Load(ctx, t, mc, loadParams)
|
||||
|
||||
// Wait for loading to complete
|
||||
time.Sleep(time.Second * 5)
|
||||
|
||||
// Upsert entities without providing nullable field (should set to null)
|
||||
upsertPkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, []int64{1, 2})
|
||||
upsertVecColumn := hp.GenColumnData(2, entity.FieldTypeFloatVector, *hp.TNewDataOption().TWithStart(100))
|
||||
|
||||
upsertRes, err := mc.Upsert(ctx, client.NewColumnBasedInsertOption(collName).WithColumns(upsertPkColumn, upsertVecColumn))
|
||||
common.CheckErr(t, err, true)
|
||||
require.EqualValues(t, 2, upsertRes.UpsertCount)
|
||||
|
||||
// Wait for consistency
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
// Query to verify nullable field is set to null
|
||||
resSet, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter(fmt.Sprintf("%s in [1, 2]", common.DefaultInt64FieldName)).WithOutputFields("*").WithConsistencyLevel(entity.ClStrong))
|
||||
common.CheckErr(t, err, true)
|
||||
|
||||
// Verify results - nullable field should be null
|
||||
require.Equal(t, 2, resSet.GetColumn("nullable_varchar").Len())
|
||||
nullableResults := resSet.GetColumn("nullable_varchar").(*column.ColumnVarChar).Data()
|
||||
require.Equal(t, "", nullableResults[0]) // null value is represented as empty string
|
||||
require.Equal(t, "", nullableResults[1]) // null value is represented as empty string
|
||||
|
||||
// Query entity that was not upserted to verify original value is preserved
|
||||
resSet3, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter(fmt.Sprintf("%s == 3", common.DefaultInt64FieldName)).WithOutputFields("*").WithConsistencyLevel(entity.ClStrong))
|
||||
common.CheckErr(t, err, true)
|
||||
|
||||
require.Equal(t, 1, resSet3.GetColumn("nullable_varchar").Len())
|
||||
nullableResult3 := resSet3.GetColumn("nullable_varchar").(*column.ColumnVarChar).Data()
|
||||
require.Equal(t, "original_3", nullableResult3[0])
|
||||
}
|
||||
|
||||
@ -74,6 +74,10 @@ func (s *MiniClusterSuite) InsertAndFlush(ctx context.Context, dbName, collectio
|
||||
|
||||
// CreateCollectionWithConfiguration builds a schema from cfg's collection name
// and dimension and creates the collection through CreateCollection.
// NOTE(review): the third ConstructSchema argument is a hard-coded true —
// presumably an auto-id flag; confirm against ConstructSchema's signature.
func (s *MiniClusterSuite) CreateCollectionWithConfiguration(ctx context.Context, cfg *CreateCollectionConfig) {
	schema := ConstructSchema(cfg.CollectionName, cfg.Dim, true)
	s.CreateCollection(ctx, cfg, schema)
}
|
||||
|
||||
func (s *MiniClusterSuite) CreateCollection(ctx context.Context, cfg *CreateCollectionConfig, schema *schemapb.CollectionSchema) {
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
s.NoError(err)
|
||||
s.NotNil(marshaledSchema)
|
||||
|
||||
1351 changed lines — tests/restful_client_v2/testcases/test_partial_update.py (new file)
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user