milvus/internal/proxy/task_upsert.go
marcelo-cjl 3b599441fd
feat: Add nullable vector support for proxy and querynode (#46305)
related: #45993 

This commit extends nullable vector support to the proxy layer and the
querynode, and adds comprehensive validation, search reduce, and field data
handling for nullable vectors with sparse storage.

Proxy layer changes:
- Update validate_util.go checkAligned() with a getExpectedVectorRows() helper
  to validate nullable vector field alignment using the valid data count
  (see the sketch after this list)
- Update checkFloatVectorFieldData/checkSparseFloatVectorFieldData for
  nullable vector validation with proper row count expectations
- Add FieldDataIdxComputer in typeutil/schema.go for logical-to-physical
  index translation during search reduce operations
- Update search_reduce_util.go reduceSearchResultData to use idxComputers
  for correct field data indexing with nullable vectors
- Update task.go, task_query.go, task_upsert.go for nullable vector handling
- Update msg_pack.go with nullable vector field data processing
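
A minimal sketch of the alignment rule the new helper enforces. The helper
name and signature are paraphrased here, not copied from validate_util.go;
the rule itself follows from the sparse storage format (a nullable vector
field physically stores only its valid rows):

```go
// Sketch only: expected physical row count for a (possibly nullable)
// vector field. With sparse storage, the data array holds one vector
// per *valid* row rather than one per logical row.
func expectedVectorRows(numRows int, validData []bool) int {
	if len(validData) == 0 {
		return numRows // not nullable: one vector per row
	}
	valid := 0
	for _, v := range validData {
		if v {
			valid++
		}
	}
	return valid
}
```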
    
QueryNode layer changes:
- Update segments/result.go for nullable vector result handling
- Update segments/search_reduce.go with nullable vector offset translation

Storage and index changes:
- Update data_codec.go and utils.go for nullable vector serialization
- Update indexcgowrapper/dataset.go and index.go for nullable vector indexing

Utility changes:
- Add FieldDataIdxComputer struct with a Compute() method for efficient
  logical-to-physical index mapping across multiple field data
  (usage sketch below)
- Update EstimateEntitySize() and AppendFieldData() with a fieldIdxs parameter
- Update funcutil.go with nullable vector support functions
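
For reference, the pattern the new computer enables. This mirrors the call
site in task_upsert.go's queryPreExecute; treat it as a sketch (dst and
rowsToCopy are placeholders), not the exact typeutil API:

```go
// existFieldData may store nullable vectors sparsely, so a row's logical
// offset can differ from its physical offset within each field's data array.
idxComputer := typeutil.NewFieldDataIdxComputer(existFieldData)
for _, logicalIdx := range rowsToCopy {
	// Compute translates one logical row index into per-field physical indexes.
	fieldIdxs := idxComputer.Compute(int64(logicalIdx))
	typeutil.AppendFieldData(dst, existFieldData, int64(logicalIdx), fieldIdxs...)
}
```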

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
  * Full support for nullable vector fields (float, binary, float16,
    bfloat16, int8, sparse) across ingest, storage, indexing, search and
    retrieval; logical↔physical offset mapping preserves row semantics.
  * Client: compaction control and compaction-state APIs.

* **Bug Fixes**
  * Improved validation for adding vector fields (nullable + dimension
    checks) and corrected search/query behavior for nullable vectors.

* **Chores**
  * Persisted validity maps with indexes and on-disk formats.

* **Tests**
  * Extensive new and updated end-to-end nullable-vector tests.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: marcelo-cjl <marcelo.chen@zilliz.com>
2025-12-24 10:13:19 +08:00


// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package proxy

import (
	"context"
	"fmt"
	"strconv"

	"github.com/cockroachdb/errors"
	"github.com/samber/lo"
	"go.opentelemetry.io/otel"
	"go.uber.org/zap"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/allocator"
	"github.com/milvus-io/milvus/internal/parser/planparserv2"
	"github.com/milvus-io/milvus/internal/types"
	"github.com/milvus-io/milvus/internal/util/segcore"
	"github.com/milvus-io/milvus/pkg/v2/common"
	"github.com/milvus-io/milvus/pkg/v2/log"
	"github.com/milvus-io/milvus/pkg/v2/metrics"
	"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
	"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
	"github.com/milvus-io/milvus/pkg/v2/util/merr"
	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
	"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)

type upsertTask struct {
	baseTask
	Condition

	upsertMsg        *msgstream.UpsertMsg
	req              *milvuspb.UpsertRequest
	baseMsg          msgstream.BaseMsg
	ctx              context.Context
	timestamps       []uint64
	rowIDs           []int64
	result           *milvuspb.MutationResult
	idAllocator      *allocator.IDAllocator
	collectionID     UniqueID
	chMgr            channelsMgr
	chTicker         channelsTimeTicker
	vChannels        []vChan
	pChannels        []pChan
	schema           *schemaInfo
	partitionKeyMode bool
	partitionKeys    *schemapb.FieldData
	// automatically generate a new pk when autoID == true;
	// the delete task needs to use the oldIDs
	oldIDs          *schemapb.IDs
	schemaTimestamp uint64
	// write after read: the write part is generated by queryPreExecute
	node            types.ProxyComponent
	deletePKs       *schemapb.IDs
	insertFieldData []*schemapb.FieldData
	storageCost     segcore.StorageCost
}

// TraceCtx returns the upsertTask context
func (it *upsertTask) TraceCtx() context.Context {
	return it.ctx
}

func (it *upsertTask) ID() UniqueID {
	return it.req.Base.MsgID
}

func (it *upsertTask) SetID(uid UniqueID) {
	it.req.Base.MsgID = uid
}

func (it *upsertTask) Name() string {
	return UpsertTaskName
}

func (it *upsertTask) Type() commonpb.MsgType {
	return it.req.Base.MsgType
}

func (it *upsertTask) BeginTs() Timestamp {
	return it.baseMsg.BeginTimestamp
}

func (it *upsertTask) SetTs(ts Timestamp) {
	it.baseMsg.BeginTimestamp = ts
	it.baseMsg.EndTimestamp = ts
}

func (it *upsertTask) EndTs() Timestamp {
	return it.baseMsg.EndTimestamp
}

func (it *upsertTask) getPChanStats() (map[pChan]pChanStatistics, error) {
	ret := make(map[pChan]pChanStatistics)
	channels := it.getChannels()
	beginTs := it.BeginTs()
	endTs := it.EndTs()
	for _, channel := range channels {
		ret[channel] = pChanStatistics{
			minTs: beginTs,
			maxTs: endTs,
		}
	}
	return ret, nil
}

func (it *upsertTask) setChannels() error {
	collID, err := globalMetaCache.GetCollectionID(it.ctx, it.req.GetDbName(), it.req.CollectionName)
	if err != nil {
		return err
	}
	channels, err := it.chMgr.getChannels(collID)
	if err != nil {
		return err
	}
	it.pChannels = channels
	return nil
}

func (it *upsertTask) getChannels() []pChan {
	return it.pChannels
}

func (it *upsertTask) OnEnqueue() error {
	if it.req.Base == nil {
		it.req.Base = commonpbutil.NewMsgBase()
	}
	it.req.Base.MsgType = commonpb.MsgType_Upsert
	it.req.Base.SourceID = paramtable.GetNodeID()
	return nil
}

func retrieveByPKs(ctx context.Context, t *upsertTask, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, segcore.StorageCost, error) {
	log := log.Ctx(ctx).With(zap.String("collectionName", t.req.GetCollectionName()))
	var err error
	queryReq := &milvuspb.QueryRequest{
		Base: &commonpb.MsgBase{
			MsgType:   commonpb.MsgType_Retrieve,
			Timestamp: t.BeginTs(),
		},
		DbName:                t.req.GetDbName(),
		CollectionName:        t.req.GetCollectionName(),
		ConsistencyLevel:      commonpb.ConsistencyLevel_Strong,
		NotReturnAllMeta:      false,
		OutputFields:          []string{"*"},
		UseDefaultConsistency: false,
		GuaranteeTimestamp:    t.BeginTs(),
		Namespace:             t.req.Namespace,
	}
	pkField, err := typeutil.GetPrimaryFieldSchema(t.schema.CollectionSchema)
	if err != nil {
		return nil, segcore.StorageCost{}, err
	}
	var partitionIDs []int64
	if t.partitionKeyMode {
		// multiple entities with the same pk but different partition keys may be hashed
		// to multiple physical partitions; if deleteMsg.partitionID = common.InvalidPartition,
		// all segments with this pk under the collection will have the delete record
		partitionIDs = []int64{common.AllPartitionsID}
		queryReq.PartitionNames = []string{}
	} else {
		// the partition name could be defaultPartitionName or a name specified by the sdk
		partName := t.upsertMsg.DeleteMsg.PartitionName
		if err := validatePartitionTag(partName, true); err != nil {
			log.Warn("Invalid partition name", zap.String("partitionName", partName), zap.Error(err))
			return nil, segcore.StorageCost{}, err
		}
		partID, err := globalMetaCache.GetPartitionID(ctx, t.req.GetDbName(), t.req.GetCollectionName(), partName)
		if err != nil {
			log.Warn("Failed to get partition id", zap.String("partitionName", partName), zap.Error(err))
			return nil, segcore.StorageCost{}, err
		}
		partitionIDs = []int64{partID}
		queryReq.PartitionNames = []string{partName}
	}
	plan := planparserv2.CreateRequeryPlan(pkField, ids)
	plan.Namespace = t.req.Namespace
	qt := &queryTask{
		ctx:       t.ctx,
		Condition: NewTaskCondition(t.ctx),
		RetrieveRequest: &internalpb.RetrieveRequest{
			Base: commonpbutil.NewMsgBase(
				commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
				commonpbutil.WithSourceID(paramtable.GetNodeID()),
			),
			ReqID:            paramtable.GetNodeID(),
			PartitionIDs:     partitionIDs,
			ConsistencyLevel: commonpb.ConsistencyLevel_Strong,
		},
		request:  queryReq,
		plan:     plan,
		mixCoord: t.node.(*Proxy).mixCoord,
		lb:       t.node.(*Proxy).lbPolicy,
	}
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-retrieveByPKs")
	defer func() {
		sp.End()
	}()
	queryResult, storageCost, err := t.node.(*Proxy).query(ctx, qt, sp)
	if err := merr.CheckRPCCall(queryResult.GetStatus(), err); err != nil {
		return nil, storageCost, err
	}
	return queryResult, storageCost, err
}

func (it *upsertTask) queryPreExecute(ctx context.Context) error {
	log := log.Ctx(ctx).With(zap.String("collectionName", it.req.CollectionName))
	primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(it.schema.CollectionSchema)
	if err != nil {
		log.Warn("get primary field schema failed", zap.Error(err))
		return err
	}
	primaryFieldData, err := typeutil.GetPrimaryFieldData(it.req.GetFieldsData(), primaryFieldSchema)
	if err != nil {
		log.Error("get primary field data failed", zap.Error(err))
		return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("must assign pk when upsert, primary field: %v", primaryFieldSchema.Name))
	}
	upsertIDs, err := parsePrimaryFieldData2IDs(primaryFieldData)
	if err != nil {
		log.Warn("parse primary field data to IDs failed", zap.Error(err))
		return err
	}
	upsertIDSize := typeutil.GetSizeOfIDs(upsertIDs)
	if upsertIDSize == 0 {
		it.deletePKs = &schemapb.IDs{}
		it.insertFieldData = it.req.GetFieldsData()
		log.Info("old records not found, just do insert")
		return nil
	}
	tr := timerecord.NewTimeRecorder("Proxy-Upsert-retrieveByPKs")
	// retrieve by primary key to get the original field data
	resp, storageCost, err := retrieveByPKs(ctx, it, upsertIDs, []string{"*"})
	if err != nil {
		log.Info("retrieve by primary key failed", zap.Error(err))
		return err
	}
	it.storageCost = storageCost
	if len(resp.GetFieldsData()) == 0 {
		return merr.WrapErrParameterInvalidMsg("retrieve by primary key failed, no data found")
	}
	existFieldData := resp.GetFieldsData()
	pkFieldData, err := typeutil.GetPrimaryFieldData(existFieldData, primaryFieldSchema)
	if err != nil {
		log.Error("get primary field data failed", zap.Error(err))
		return err
	}
	existIDs, err := parsePrimaryFieldData2IDs(pkFieldData)
	if err != nil {
		log.Info("parse primary field data to ids failed", zap.Error(err))
		return err
	}
	log.Info("retrieveByPKs cost",
		zap.Int("resultNum", typeutil.GetSizeOfIDs(existIDs)),
		zap.Int64("latency", tr.ElapseSpan().Milliseconds()))
	// set the field id on the user-passed field data, preparing for the merge logic
	if len(it.upsertMsg.InsertMsg.GetFieldsData()) == 0 {
		return merr.WrapErrParameterInvalidMsg("upsert field data is empty")
	}
	for _, fieldData := range it.upsertMsg.InsertMsg.GetFieldsData() {
		fieldName := fieldData.GetFieldName()
		if fieldData.GetIsDynamic() {
			fieldName = "$meta"
		}
		fieldSchema, err := it.schema.schemaHelper.GetFieldFromName(fieldName)
		if err != nil {
			log.Info("get field schema failed", zap.Error(err))
			return err
		}
		fieldData.FieldId = fieldSchema.GetFieldID()
		fieldData.FieldName = fieldName
		// stay compatible with the different nullable data formats sent by SDKs
		if len(fieldData.GetValidData()) != 0 {
			err := FillWithNullValue(fieldData, fieldSchema, int(it.upsertMsg.InsertMsg.NRows()))
			if err != nil {
				log.Info("unify null field data format failed", zap.Error(err))
				return err
			}
		}
	}
	// Validate field data alignment before processing to prevent an index-out-of-range panic
	if err := newValidateUtil().checkAligned(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema.schemaHelper, uint64(upsertIDSize)); err != nil {
		log.Warn("check field data aligned failed", zap.Error(err))
		return err
	}
	// Two nullable data formats are supported:
	//
	// COMPRESSED FORMAT (SDK format, before validateUtil.fillWithValue processing):
	//   Logical data: [1, null, 2]
	//   Storage: Data=[1, 2] + ValidData=[true, false, true]
	//   - the Data array contains only non-null values (compressed)
	//   - the ValidData array tracks null positions for all rows
	//
	// FULL FORMAT (Milvus internal format, after validateUtil.fillWithValue processing):
	//   Logical data: [1, null, 2]
	//   Storage: Data=[1, 0, 2] + ValidData=[true, false, true]
	//   - the Data array contains values for all rows (nulls filled with zero/default)
	//   - the ValidData array still tracks null positions
	//
	// Note: we will unify the nullable format to FULL FORMAT before executing the merge logic
	insertIdxInUpsert := make([]int, 0)
	updateIdxInUpsert := make([]int, 0)
	// 1. split the upsert data into insert and update parts using the query result
	idsChecker, err := typeutil.NewIDsChecker(existIDs)
	if err != nil {
		log.Info("create primary key checker failed", zap.Error(err))
		return err
	}
	for upsertIdx := 0; upsertIdx < upsertIDSize; upsertIdx++ {
		exist, err := idsChecker.Contains(upsertIDs, upsertIdx)
		if err != nil {
			log.Info("check primary key exist in query result failed", zap.Error(err))
			return err
		}
		if exist {
			updateIdxInUpsert = append(updateIdxInUpsert, upsertIdx)
		} else {
			insertIdxInUpsert = append(insertIdxInUpsert, upsertIdx)
		}
	}
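	// Example: with upsert PKs [11, 22, 33] and existing PKs {22, 33} in the
	// query result, the split yields insertIdxInUpsert=[0] (pk 11 is new) and
	// updateIdxInUpsert=[1, 2] (pks 22 and 33 already exist).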
	// 2. merge field data with update semantics
	it.deletePKs = &schemapb.IDs{}
	it.insertFieldData = typeutil.PrepareResultFieldData(existFieldData, int64(upsertIDSize))
	if len(updateIdxInUpsert) > 0 {
		// Note: for fields with default values, the default must be applied according to
		// valid data during insertion, but query results do not set valid data when
		// returning default-value fields, so valid data has to be manually set to true
		for _, fieldData := range existFieldData {
			fieldSchema, err := it.schema.schemaHelper.GetFieldFromName(fieldData.GetFieldName())
			if err != nil {
				log.Info("get field schema failed", zap.Error(err))
				return err
			}
			if fieldSchema.GetDefaultValue() != nil {
				fieldData.ValidData = make([]bool, upsertIDSize)
				for i := range fieldData.ValidData {
					fieldData.ValidData[i] = true
				}
			}
		}
		// Build a mapping from existing primary keys to their positions in the query result.
		// This ensures we can correctly locate data even if the query results are not in the
		// same order as the request
		existIDsLen := typeutil.GetSizeOfIDs(existIDs)
		existPKToIndex := make(map[interface{}]int, existIDsLen)
		for j := 0; j < existIDsLen; j++ {
			pk := typeutil.GetPK(existIDs, int64(j))
			existPKToIndex[pk] = j
		}
		baseIdx := 0
		idxComputer := typeutil.NewFieldDataIdxComputer(existFieldData)
		for _, idx := range updateIdxInUpsert {
			typeutil.AppendIDs(it.deletePKs, upsertIDs, idx)
			oldPK := typeutil.GetPK(upsertIDs, int64(idx))
			existIndex, ok := existPKToIndex[oldPK]
			if !ok {
				return merr.WrapErrParameterInvalidMsg("primary key not found in exist data mapping")
			}
			fieldIdxs := idxComputer.Compute(int64(existIndex))
			typeutil.AppendFieldData(it.insertFieldData, existFieldData, int64(existIndex), fieldIdxs...)
			err := typeutil.UpdateFieldData(it.insertFieldData, it.upsertMsg.InsertMsg.GetFieldsData(), int64(baseIdx), int64(idx))
			baseIdx += 1
			if err != nil {
				log.Info("update field data failed", zap.Error(err))
				return err
			}
		}
	}
	// 3. merge field data with insert semantics
	if len(insertIdxInUpsert) > 0 {
		// if a required field is missing from the upsert request, return an error
		lackOfFieldErr := LackOfFieldsDataBySchema(it.schema.CollectionSchema, it.upsertMsg.InsertMsg.GetFieldsData(), false, true)
		if lackOfFieldErr != nil {
			log.Info("check fields data by schema failed", zap.Error(lackOfFieldErr))
			return lackOfFieldErr
		}
		// if a nullable field was not passed in the upsert request, then
		// len(upsertFieldData) < len(it.insertFieldData), so we need to generate
		// the nullable field data before appending it as an insert
		insertWithNullField := make([]*schemapb.FieldData, 0)
		upsertFieldMap := lo.SliceToMap(it.upsertMsg.InsertMsg.GetFieldsData(), func(field *schemapb.FieldData) (string, *schemapb.FieldData) {
			return field.GetFieldName(), field
		})
		for _, fieldSchema := range it.schema.CollectionSchema.Fields {
			if fieldData, ok := upsertFieldMap[fieldSchema.Name]; !ok {
				if fieldSchema.GetNullable() || fieldSchema.GetDefaultValue() != nil {
					fieldData, err := GenNullableFieldData(fieldSchema, upsertIDSize)
					if err != nil {
						log.Info("generate nullable field data failed", zap.Error(err))
						return err
					}
					insertWithNullField = append(insertWithNullField, fieldData)
				}
			} else {
				insertWithNullField = append(insertWithNullField, fieldData)
			}
		}
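		// For nullable vector fields the upsert data may still be in COMPRESSED
		// FORMAT (the vector data array holds only valid rows), so a row's logical
		// offset cannot index the vector data directly. vectorIdxMap[rowIdx][fieldIdx]
		// holds the physical offset to read for each inserted row and field: it
		// defaults to the logical offset and is overridden below for compressed
		// nullable vector fields.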
		vectorIdxMap := make([][]int64, len(insertIdxInUpsert))
		for rowIdx, offset := range insertIdxInUpsert {
			vectorIdxMap[rowIdx] = make([]int64, len(insertWithNullField))
			for fieldIdx := range insertWithNullField {
				vectorIdxMap[rowIdx][fieldIdx] = int64(offset)
			}
		}
		for fieldIdx, fieldData := range insertWithNullField {
			validData := fieldData.GetValidData()
			if len(validData) > 0 && typeutil.IsVectorType(fieldData.Type) {
				dataIdx := int64(0)
				rowIdx := 0
				for i := 0; i < len(validData) && rowIdx < len(insertIdxInUpsert); i++ {
					if i == insertIdxInUpsert[rowIdx] {
						vectorIdxMap[rowIdx][fieldIdx] = dataIdx
						rowIdx++
					}
					if validData[i] {
						dataIdx++
					}
				}
			}
		}
		for rowIdx, idx := range insertIdxInUpsert {
			typeutil.AppendFieldData(it.insertFieldData, insertWithNullField, int64(idx), vectorIdxMap[rowIdx]...)
		}
	}
	for _, fieldData := range it.insertFieldData {
		if len(fieldData.GetValidData()) > 0 {
			err := ToCompressedFormatNullable(fieldData)
			if err != nil {
				log.Info("convert to compressed format nullable failed", zap.Error(err))
				return err
			}
		}
	}
	return nil
}

// ToCompressedFormatNullable converts the field data from full format nullable to compressed format nullable
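// Illustrative example (values taken from the format comment in queryPreExecute):
// a nullable Int64 field in FULL FORMAT with Data=[1, 0, 2] and
// ValidData=[true, false, true] becomes Data=[1, 2] with the same ValidData,
// i.e. the data array keeps only the valid rows.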
func ToCompressedFormatNullable(field *schemapb.FieldData) error {
	if getValidNumber(field.GetValidData()) == len(field.GetValidData()) {
		return nil
	}
	switch field.Field.(type) {
	case *schemapb.FieldData_Scalars:
		switch sd := field.GetScalars().GetData().(type) {
		case *schemapb.ScalarField_BoolData:
			validRowNum := getValidNumber(field.GetValidData())
			if validRowNum == 0 {
				sd.BoolData.Data = make([]bool, 0)
			} else {
				ret := make([]bool, 0, validRowNum)
				for i, valid := range field.GetValidData() {
					if valid {
						ret = append(ret, sd.BoolData.Data[i])
					}
				}
				sd.BoolData.Data = ret
			}
		case *schemapb.ScalarField_IntData:
			validRowNum := getValidNumber(field.GetValidData())
			if validRowNum == 0 {
				sd.IntData.Data = make([]int32, 0)
			} else {
				ret := make([]int32, 0, validRowNum)
				for i, valid := range field.GetValidData() {
					if valid {
						ret = append(ret, sd.IntData.Data[i])
					}
				}
				sd.IntData.Data = ret
			}
		case *schemapb.ScalarField_LongData:
			validRowNum := getValidNumber(field.GetValidData())
			if validRowNum == 0 {
				sd.LongData.Data = make([]int64, 0)
			} else {
				ret := make([]int64, 0, validRowNum)
				for i, valid := range field.GetValidData() {
					if valid {
						ret = append(ret, sd.LongData.Data[i])
					}
				}
				sd.LongData.Data = ret
			}
		case *schemapb.ScalarField_FloatData:
			validRowNum := getValidNumber(field.GetValidData())
			if validRowNum == 0 {
				sd.FloatData.Data = make([]float32, 0)
			} else {
				ret := make([]float32, 0, validRowNum)
				for i, valid := range field.GetValidData() {
					if valid {
						ret = append(ret, sd.FloatData.Data[i])
					}
				}
				sd.FloatData.Data = ret
			}
		case *schemapb.ScalarField_DoubleData:
			validRowNum := getValidNumber(field.GetValidData())
			if validRowNum == 0 {
				sd.DoubleData.Data = make([]float64, 0)
			} else {
				ret := make([]float64, 0, validRowNum)
				for i, valid := range field.GetValidData() {
					if valid {
						ret = append(ret, sd.DoubleData.Data[i])
					}
				}
				sd.DoubleData.Data = ret
			}
		case *schemapb.ScalarField_StringData:
			validRowNum := getValidNumber(field.GetValidData())
			if validRowNum == 0 {
				sd.StringData.Data = make([]string, 0)
			} else {
				ret := make([]string, 0, validRowNum)
				for i, valid := range field.GetValidData() {
					if valid {
						ret = append(ret, sd.StringData.Data[i])
					}
				}
				sd.StringData.Data = ret
			}
		case *schemapb.ScalarField_JsonData:
			validRowNum := getValidNumber(field.GetValidData())
			if validRowNum == 0 {
				sd.JsonData.Data = make([][]byte, 0)
			} else {
				ret := make([][]byte, 0, validRowNum)
				for i, valid := range field.GetValidData() {
					if valid {
						ret = append(ret, sd.JsonData.Data[i])
					}
				}
				sd.JsonData.Data = ret
			}
		case *schemapb.ScalarField_ArrayData:
			validRowNum := getValidNumber(field.GetValidData())
			if validRowNum == 0 {
				sd.ArrayData.Data = make([]*schemapb.ScalarField, 0)
			} else {
				ret := make([]*schemapb.ScalarField, 0, validRowNum)
				for i, valid := range field.GetValidData() {
					if valid {
						ret = append(ret, sd.ArrayData.Data[i])
					}
				}
				sd.ArrayData.Data = ret
			}
		case *schemapb.ScalarField_TimestamptzData:
			validRowNum := getValidNumber(field.GetValidData())
			if validRowNum == 0 {
				sd.TimestamptzData.Data = make([]int64, 0)
			} else {
				ret := make([]int64, 0, validRowNum)
				for i, valid := range field.GetValidData() {
					if valid {
						ret = append(ret, sd.TimestamptzData.Data[i])
					}
				}
				sd.TimestamptzData.Data = ret
			}
		case *schemapb.ScalarField_GeometryWktData:
			validRowNum := getValidNumber(field.GetValidData())
			if validRowNum == 0 {
				sd.GeometryWktData.Data = make([]string, 0)
			} else {
				ret := make([]string, 0, validRowNum)
				for i, valid := range field.GetValidData() {
					if valid {
						ret = append(ret, sd.GeometryWktData.Data[i])
					}
				}
				sd.GeometryWktData.Data = ret
			}
		case *schemapb.ScalarField_GeometryData:
			validRowNum := getValidNumber(field.GetValidData())
			if validRowNum == 0 {
				sd.GeometryData.Data = make([][]byte, 0)
			} else {
				ret := make([][]byte, 0, validRowNum)
				for i, valid := range field.GetValidData() {
					if valid {
						ret = append(ret, sd.GeometryData.Data[i])
					}
				}
				sd.GeometryData.Data = ret
			}
		default:
			return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("undefined data type:%s", field.Type.String()))
		}
	case *schemapb.FieldData_Vectors:
		// Vector data is already in compressed format, skip
		return nil
	default:
		return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("undefined data type:%s", field.Type.String()))
	}
	return nil
}

// GenNullableFieldData generates nullable field data in FULL FORMAT
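// The generated data is all-null: ValidData is all false and the data array is
// zero-filled to upsertIDSize entries (e.g. for an Int64 field with
// upsertIDSize=3: Data=[0, 0, 0], ValidData=[false, false, false]).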
func GenNullableFieldData(field *schemapb.FieldSchema, upsertIDSize int) (*schemapb.FieldData, error) {
	switch field.DataType {
	case schemapb.DataType_Bool:
		return &schemapb.FieldData{
			FieldId:   field.FieldID,
			FieldName: field.Name,
			Type:      field.DataType,
			IsDynamic: field.IsDynamic,
			ValidData: make([]bool, upsertIDSize),
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_BoolData{
						BoolData: &schemapb.BoolArray{
							Data: make([]bool, upsertIDSize),
						},
					},
				},
			},
		}, nil
	case schemapb.DataType_Int32:
		return &schemapb.FieldData{
			FieldId:   field.FieldID,
			FieldName: field.Name,
			Type:      field.DataType,
			IsDynamic: field.IsDynamic,
			ValidData: make([]bool, upsertIDSize),
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_IntData{
						IntData: &schemapb.IntArray{
							Data: make([]int32, upsertIDSize),
						},
					},
				},
			},
		}, nil
	case schemapb.DataType_Int64:
		return &schemapb.FieldData{
			FieldId:   field.FieldID,
			FieldName: field.Name,
			Type:      field.DataType,
			IsDynamic: field.IsDynamic,
			ValidData: make([]bool, upsertIDSize),
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_LongData{
						LongData: &schemapb.LongArray{
							Data: make([]int64, upsertIDSize),
						},
					},
				},
			},
		}, nil
	case schemapb.DataType_Float:
		return &schemapb.FieldData{
			FieldId:   field.FieldID,
			FieldName: field.Name,
			Type:      field.DataType,
			IsDynamic: field.IsDynamic,
			ValidData: make([]bool, upsertIDSize),
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_FloatData{
						FloatData: &schemapb.FloatArray{
							Data: make([]float32, upsertIDSize),
						},
					},
				},
			},
		}, nil
	case schemapb.DataType_Double:
		return &schemapb.FieldData{
			FieldId:   field.FieldID,
			FieldName: field.Name,
			Type:      field.DataType,
			IsDynamic: field.IsDynamic,
			ValidData: make([]bool, upsertIDSize),
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_DoubleData{
						DoubleData: &schemapb.DoubleArray{
							Data: make([]float64, upsertIDSize),
						},
					},
				},
			},
		}, nil
	case schemapb.DataType_VarChar:
		return &schemapb.FieldData{
			FieldId:   field.FieldID,
			FieldName: field.Name,
			Type:      field.DataType,
			IsDynamic: field.IsDynamic,
			ValidData: make([]bool, upsertIDSize),
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_StringData{
						StringData: &schemapb.StringArray{
							Data: make([]string, upsertIDSize),
						},
					},
				},
			},
		}, nil
	case schemapb.DataType_JSON:
		return &schemapb.FieldData{
			FieldId:   field.FieldID,
			FieldName: field.Name,
			Type:      field.DataType,
			IsDynamic: field.IsDynamic,
			ValidData: make([]bool, upsertIDSize),
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_JsonData{
						JsonData: &schemapb.JSONArray{
							Data: make([][]byte, upsertIDSize),
						},
					},
				},
			},
		}, nil
	case schemapb.DataType_Array:
		return &schemapb.FieldData{
			FieldId:   field.FieldID,
			FieldName: field.Name,
			Type:      field.DataType,
			IsDynamic: field.IsDynamic,
			ValidData: make([]bool, upsertIDSize),
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_ArrayData{
						ArrayData: &schemapb.ArrayArray{
							Data: make([]*schemapb.ScalarField, upsertIDSize),
						},
					},
				},
			},
		}, nil
	case schemapb.DataType_Timestamptz:
		return &schemapb.FieldData{
			FieldId:   field.FieldID,
			FieldName: field.Name,
			Type:      field.DataType,
			IsDynamic: field.IsDynamic,
			ValidData: make([]bool, upsertIDSize),
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_TimestamptzData{
						TimestamptzData: &schemapb.TimestamptzArray{
							Data: make([]int64, upsertIDSize),
						},
					},
				},
			},
		}, nil
	// the input data of the geometry field is in WKT format
	case schemapb.DataType_Geometry:
		return &schemapb.FieldData{
			FieldId:   field.FieldID,
			FieldName: field.Name,
			Type:      field.DataType,
			IsDynamic: field.IsDynamic,
			ValidData: make([]bool, upsertIDSize),
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_GeometryWktData{
						GeometryWktData: &schemapb.GeometryWktArray{
							Data: make([]string, upsertIDSize),
						},
					},
				},
			},
		}, nil
	default:
		return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("undefined scalar data type:%s", field.DataType.String()))
	}
}

func (it *upsertTask) insertPreExecute(ctx context.Context) error {
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-insertPreExecute")
	defer sp.End()
	collectionName := it.upsertMsg.InsertMsg.CollectionName
	if err := validateCollectionName(collectionName); err != nil {
		log.Ctx(ctx).Error("valid collection name failed", zap.String("collectionName", collectionName), zap.Error(err))
		return err
	}
	bm25Fields := typeutil.NewSet[string](GetBM25FunctionOutputFields(it.schema.CollectionSchema)...)
	if it.req.PartialUpdate {
		// remove the old bm25 fields
		ret := make([]*schemapb.FieldData, 0)
		for _, fieldData := range it.upsertMsg.InsertMsg.GetFieldsData() {
			if bm25Fields.Contain(fieldData.GetFieldName()) {
				continue
			}
			ret = append(ret, fieldData)
		}
		it.upsertMsg.InsertMsg.FieldsData = ret
	}
	rowNums := uint32(it.upsertMsg.InsertMsg.NRows())
	// set upsertTask.insertRequest.rowIDs
	tr := timerecord.NewTimeRecorder("applyPK")
	clusterID := Params.CommonCfg.ClusterID.GetAsUint64()
	rowIDBegin, rowIDEnd, _ := common.AllocAutoID(it.idAllocator.Alloc, rowNums, clusterID)
	metrics.ProxyApplyPrimaryKeyLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(tr.ElapseSpan().Milliseconds()))
	it.upsertMsg.InsertMsg.RowIDs = make([]UniqueID, rowNums)
	it.rowIDs = make([]UniqueID, rowNums)
	for i := rowIDBegin; i < rowIDEnd; i++ {
		offset := i - rowIDBegin
		it.upsertMsg.InsertMsg.RowIDs[offset] = i
		it.rowIDs[offset] = i
	}
	// set upsertTask.insertRequest.timeStamps
	rowNum := it.upsertMsg.InsertMsg.NRows()
	it.upsertMsg.InsertMsg.Timestamps = make([]uint64, rowNum)
	it.timestamps = make([]uint64, rowNum)
	for index := range it.timestamps {
		it.upsertMsg.InsertMsg.Timestamps[index] = it.BeginTs()
		it.timestamps[index] = it.BeginTs()
	}
	// set result.SuccIndex
	sliceIndex := make([]uint32, rowNums)
	for i := uint32(0); i < rowNums; i++ {
		sliceIndex[i] = i
	}
	it.result.SuccIndex = sliceIndex
	if it.schema.EnableDynamicField {
		err := checkDynamicFieldData(it.schema.CollectionSchema, it.upsertMsg.InsertMsg)
		if err != nil {
			return err
		}
	}
	if Params.CommonCfg.EnableNamespace.GetAsBool() {
		err := addNamespaceData(it.schema.CollectionSchema, it.upsertMsg.InsertMsg)
		if err != nil {
			return err
		}
	}
	if err := checkAndFlattenStructFieldData(it.schema.CollectionSchema, it.upsertMsg.InsertMsg); err != nil {
		return err
	}
	allFields := typeutil.GetAllFieldSchemas(it.schema.CollectionSchema)
	// use the passed pk as the new pk when autoID == false;
	// automatically generate a new pk when autoID == true
	var err error
	it.result.IDs, it.oldIDs, err = checkUpsertPrimaryFieldData(allFields, it.schema.CollectionSchema, it.upsertMsg.InsertMsg)
	log := log.Ctx(ctx).With(zap.String("collectionName", it.upsertMsg.InsertMsg.CollectionName))
	if err != nil {
		log.Warn("check primary field data and hash primary key failed when upsert",
			zap.Error(err))
		return merr.WrapErrAsInputErrorWhen(err, merr.ErrParameterInvalid)
	}
	// check that varchar/text fields with an analyzer are valid UTF-8
	err = checkInputUtf8Compatiable(allFields, it.upsertMsg.InsertMsg)
	if err != nil {
		log.Warn("check varchar/text format failed", zap.Error(err))
		return err
	}
	// validate the field data columns and set field IDs on the insert field data
	err = validateFieldDataColumns(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema)
	if err != nil {
		log.Warn("validate field data columns failed when upsert", zap.Error(err))
		return merr.WrapErrAsInputErrorWhen(err, merr.ErrParameterInvalid)
	}
	err = fillFieldPropertiesOnly(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema)
	if err != nil {
		log.Warn("fill field properties failed when upsert", zap.Error(err))
		return merr.WrapErrAsInputErrorWhen(err, merr.ErrParameterInvalid)
	}
	if it.partitionKeyMode {
		fieldSchema, _ := typeutil.GetPartitionKeyFieldSchema(it.schema.CollectionSchema)
		it.partitionKeys, err = getPartitionKeyFieldData(fieldSchema, it.upsertMsg.InsertMsg)
		if err != nil {
			log.Warn("get partition keys from insert request failed",
				zap.String("collectionName", collectionName),
				zap.Error(err))
			return err
		}
	} else {
		partitionTag := it.upsertMsg.InsertMsg.PartitionName
		if err = validatePartitionTag(partitionTag, true); err != nil {
			log.Warn("valid partition name failed", zap.String("partition name", partitionTag), zap.Error(err))
			return err
		}
	}
	if err := newValidateUtil(withNANCheck(), withOverflowCheck(), withMaxLenCheck()).
		Validate(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema.schemaHelper, it.upsertMsg.InsertMsg.NRows()); err != nil {
		return err
	}
	log.Debug("Proxy Upsert insertPreExecute done")
	return nil
}

func (it *upsertTask) deletePreExecute(ctx context.Context) error {
	collName := it.upsertMsg.DeleteMsg.CollectionName
	log := log.Ctx(ctx).With(
		zap.String("collectionName", collName))
	if it.upsertMsg.DeleteMsg.PrimaryKeys == nil {
		// if primary keys were not set by queryPreExecute, use oldIDs to delete all given records
		it.upsertMsg.DeleteMsg.PrimaryKeys = it.oldIDs
	}
	if typeutil.GetSizeOfIDs(it.upsertMsg.DeleteMsg.PrimaryKeys) == 0 {
		log.Info("deletePKs is empty, skip deleteExecute")
		return nil
	}
	if err := validateCollectionName(collName); err != nil {
		log.Info("Invalid collectionName", zap.Error(err))
		return err
	}
	if it.partitionKeyMode {
		// multiple entities with the same pk but different partition keys may be hashed
		// to multiple physical partitions; if deleteMsg.partitionID = common.InvalidPartition,
		// all segments with this pk under the collection will have the delete record
		it.upsertMsg.DeleteMsg.PartitionID = common.AllPartitionsID
	} else {
		// the partition name could be defaultPartitionName or a name specified by the sdk
		partName := it.upsertMsg.DeleteMsg.PartitionName
		if err := validatePartitionTag(partName, true); err != nil {
			log.Warn("Invalid partition name", zap.String("partitionName", partName), zap.Error(err))
			return err
		}
		partID, err := globalMetaCache.GetPartitionID(ctx, it.req.GetDbName(), collName, partName)
		if err != nil {
			log.Warn("Failed to get partition id", zap.String("collectionName", collName), zap.String("partitionName", partName), zap.Error(err))
			return err
		}
		it.upsertMsg.DeleteMsg.PartitionID = partID
	}
	it.upsertMsg.DeleteMsg.Timestamps = make([]uint64, it.upsertMsg.DeleteMsg.NumRows)
	for index := range it.upsertMsg.DeleteMsg.Timestamps {
		it.upsertMsg.DeleteMsg.Timestamps[index] = it.BeginTs()
	}
	log.Debug("Proxy Upsert deletePreExecute done")
	return nil
}

func (it *upsertTask) PreExecute(ctx context.Context) error {
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-PreExecute")
	defer sp.End()
	collectionName := it.req.CollectionName
	log := log.Ctx(ctx).With(zap.String("collectionName", collectionName))
	it.result = &milvuspb.MutationResult{
		Status: merr.Success(),
		IDs: &schemapb.IDs{
			IdField: nil,
		},
		Timestamp: it.EndTs(),
	}
	replicateID, err := GetReplicateID(ctx, it.req.GetDbName(), collectionName)
	if err != nil {
		log.Warn("get replicate info failed", zap.String("collectionName", collectionName), zap.Error(err))
		return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound)
	}
	if replicateID != "" {
		return merr.WrapErrCollectionReplicateMode("upsert")
	}
	// check that the collection exists
	collID, err := globalMetaCache.GetCollectionID(context.Background(), it.req.GetDbName(), collectionName)
	if err != nil {
		log.Warn("fail to get collection id", zap.Error(err))
		return err
	}
	it.collectionID = collID
	colInfo, err := globalMetaCache.GetCollectionInfo(ctx, it.req.GetDbName(), collectionName, collID)
	if err != nil {
		log.Warn("fail to get collection info", zap.Error(err))
		return err
	}
	if it.schemaTimestamp != 0 {
		if it.schemaTimestamp != colInfo.updateTimestamp {
			err := merr.WrapErrCollectionSchemaMisMatch(collectionName)
			log.Info("collection schema mismatch", zap.String("collectionName", collectionName),
				zap.Uint64("requestSchemaTs", it.schemaTimestamp),
				zap.Uint64("collectionSchemaTs", colInfo.updateTimestamp),
				zap.Error(err))
			return err
		}
	}
	schema, err := globalMetaCache.GetCollectionSchema(ctx, it.req.GetDbName(), collectionName)
	if err != nil {
		log.Warn("Failed to get collection schema",
			zap.String("collectionName", collectionName),
			zap.Error(err))
		return err
	}
	it.schema = schema
	err = common.CheckNamespace(schema.CollectionSchema, it.req.Namespace)
	if err != nil {
		return err
	}
	it.partitionKeyMode, err = isPartitionKeyMode(ctx, it.req.GetDbName(), collectionName)
	if err != nil {
		log.Warn("check partition key mode failed",
			zap.String("collectionName", collectionName),
			zap.Error(err))
		return err
	}
	if it.partitionKeyMode {
		if len(it.req.GetPartitionName()) > 0 {
			return errors.New("not support manually specifying the partition names if partition key mode is used")
		}
	} else {
		// set the default partition name if the partition key is not used;
		// insert into the _default partition
		partitionTag := it.req.GetPartitionName()
		if len(partitionTag) <= 0 {
			pinfo, err := globalMetaCache.GetPartitionInfo(ctx, it.req.GetDbName(), collectionName, "")
			if err != nil {
				log.Warn("get partition info failed", zap.String("collectionName", collectionName), zap.Error(err))
				return err
			}
			it.req.PartitionName = pinfo.name
		}
	}
	// check for duplicate primary keys in the same batch
	primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema.CollectionSchema)
	if err != nil {
		log.Warn("fail to get primary field schema", zap.Error(err))
		return err
	}
	duplicate, err := CheckDuplicatePkExist(primaryFieldSchema, it.req.GetFieldsData())
	if err != nil {
		log.Warn("fail to check duplicate primary keys", zap.Error(err))
		return err
	}
	if duplicate {
		return merr.WrapErrParameterInvalidMsg("duplicate primary keys are not allowed in the same batch")
	}
	it.upsertMsg = &msgstream.UpsertMsg{
		InsertMsg: &msgstream.InsertMsg{
			InsertRequest: &msgpb.InsertRequest{
				Base: commonpbutil.NewMsgBase(
					commonpbutil.WithMsgType(commonpb.MsgType_Insert),
					commonpbutil.WithSourceID(paramtable.GetNodeID()),
				),
				CollectionName: it.req.CollectionName,
				CollectionID:   it.collectionID,
				PartitionName:  it.req.PartitionName,
				FieldsData:     it.req.FieldsData,
				NumRows:        uint64(it.req.NumRows),
				Version:        msgpb.InsertDataVersion_ColumnBased,
				DbName:         it.req.DbName,
				Namespace:      it.req.Namespace,
			},
		},
		DeleteMsg: &msgstream.DeleteMsg{
			DeleteRequest: &msgpb.DeleteRequest{
				Base: commonpbutil.NewMsgBase(
					commonpbutil.WithMsgType(commonpb.MsgType_Delete),
					commonpbutil.WithSourceID(paramtable.GetNodeID()),
				),
				DbName:         it.req.DbName,
				CollectionName: it.req.CollectionName,
				CollectionID:   it.collectionID,
				NumRows:        int64(it.req.NumRows),
				PartitionName:  it.req.PartitionName,
			},
		},
	}
	// check whether num_rows is valid
	if it.req.NumRows <= 0 {
		return merr.WrapErrParameterInvalid("invalid num_rows", fmt.Sprint(it.req.NumRows), "num_rows should be greater than 0")
	}
	if err := genFunctionFields(ctx, it.upsertMsg.InsertMsg, it.schema, it.req.GetPartialUpdate()); err != nil {
		return err
	}
	if it.req.GetPartialUpdate() {
		err = it.queryPreExecute(ctx)
		if err != nil {
			log.Warn("Fail to queryPreExecute", zap.Error(err))
			return err
		}
		// reconstruct the upsert msg after queryPreExecute
		it.upsertMsg.InsertMsg.FieldsData = it.insertFieldData
		it.upsertMsg.DeleteMsg.PrimaryKeys = it.deletePKs
		it.upsertMsg.DeleteMsg.NumRows = int64(typeutil.GetSizeOfIDs(it.deletePKs))
	}
	err = it.insertPreExecute(ctx)
	if err != nil {
		log.Warn("Fail to insertPreExecute", zap.Error(err))
		return err
	}
	err = it.deletePreExecute(ctx)
	if err != nil {
		log.Warn("Fail to deletePreExecute", zap.Error(err))
		return err
	}
	it.result.DeleteCnt = it.upsertMsg.DeleteMsg.NumRows
	it.result.InsertCnt = int64(it.upsertMsg.InsertMsg.NumRows)
	if it.result.DeleteCnt != it.result.InsertCnt {
		log.Info("DeleteCnt and InsertCnt are not the same when upsert",
			zap.Int64("DeleteCnt", it.result.DeleteCnt),
			zap.Int64("InsertCnt", it.result.InsertCnt))
	}
	it.result.UpsertCnt = it.result.InsertCnt
	log.Debug("Proxy Upsert PreExecute done")
	return nil
}

func (it *upsertTask) PostExecute(ctx context.Context) error {
	return nil
}