related: #45993

This commit extends nullable vector support to the proxy layer and the querynode, and adds validation, search-reduce, and field-data handling for nullable vectors with sparse storage.

Proxy layer changes:
- Update validate_util.go checkAligned() with a getExpectedVectorRows() helper to validate nullable vector field alignment using the valid-data count
- Update checkFloatVectorFieldData/checkSparseFloatVectorFieldData for nullable vector validation with correct row-count expectations
- Add FieldDataIdxComputer in typeutil/schema.go for logical-to-physical index translation during search-reduce operations
- Update search_reduce_util.go reduceSearchResultData to use idxComputers for correct field-data indexing with nullable vectors
- Update task.go, task_query.go, and task_upsert.go for nullable vector handling
- Update msg_pack.go with nullable vector field-data processing

QueryNode layer changes:
- Update segments/result.go for nullable vector result handling
- Update segments/search_reduce.go with nullable vector offset translation

Storage and index changes:
- Update data_codec.go and utils.go for nullable vector serialization
- Update indexcgowrapper/dataset.go and index.go for nullable vector indexing

Utility changes:
- Add a FieldDataIdxComputer struct with a Compute() method for efficient logical-to-physical index mapping across multiple field data
- Update EstimateEntitySize() and AppendFieldData() with a fieldIdxs parameter
- Update funcutil.go with nullable vector support functions

Signed-off-by: marcelo-cjl <marcelo.chen@zilliz.com>
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package proxy

import (
    "context"
    "fmt"
    "strconv"

    "github.com/cockroachdb/errors"
    "github.com/samber/lo"
    "go.opentelemetry.io/otel"
    "go.uber.org/zap"

    "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
    "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
    "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
    "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
    "github.com/milvus-io/milvus/internal/allocator"
    "github.com/milvus-io/milvus/internal/parser/planparserv2"
    "github.com/milvus-io/milvus/internal/types"
    "github.com/milvus-io/milvus/internal/util/segcore"
    "github.com/milvus-io/milvus/pkg/v2/common"
    "github.com/milvus-io/milvus/pkg/v2/log"
    "github.com/milvus-io/milvus/pkg/v2/metrics"
    "github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
    "github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
    "github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
    "github.com/milvus-io/milvus/pkg/v2/util/merr"
    "github.com/milvus-io/milvus/pkg/v2/util/paramtable"
    "github.com/milvus-io/milvus/pkg/v2/util/timerecord"
    "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)

type upsertTask struct {
    baseTask
    Condition

    upsertMsg *msgstream.UpsertMsg
    req       *milvuspb.UpsertRequest
    baseMsg   msgstream.BaseMsg

    ctx context.Context

    timestamps       []uint64
    rowIDs           []int64
    result           *milvuspb.MutationResult
    idAllocator      *allocator.IDAllocator
    collectionID     UniqueID
    chMgr            channelsMgr
    chTicker         channelsTimeTicker
    vChannels        []vChan
    pChannels        []pChan
    schema           *schemaInfo
    partitionKeyMode bool
    partitionKeys    *schemapb.FieldData
    // automatically generate a new pk when autoID == true;
    // the delete task needs the oldIDs
    oldIDs          *schemapb.IDs
    schemaTimestamp uint64

    // write after read: the write part is generated by queryPreExecute
    node types.ProxyComponent

    deletePKs       *schemapb.IDs
    insertFieldData []*schemapb.FieldData

    storageCost segcore.StorageCost
}

// TraceCtx returns upsertTask context
func (it *upsertTask) TraceCtx() context.Context {
    return it.ctx
}

func (it *upsertTask) ID() UniqueID {
    return it.req.Base.MsgID
}

func (it *upsertTask) SetID(uid UniqueID) {
    it.req.Base.MsgID = uid
}

func (it *upsertTask) Name() string {
    return UpsertTaskName
}

func (it *upsertTask) Type() commonpb.MsgType {
    return it.req.Base.MsgType
}

func (it *upsertTask) BeginTs() Timestamp {
    return it.baseMsg.BeginTimestamp
}

func (it *upsertTask) SetTs(ts Timestamp) {
    it.baseMsg.BeginTimestamp = ts
    it.baseMsg.EndTimestamp = ts
}

func (it *upsertTask) EndTs() Timestamp {
    return it.baseMsg.EndTimestamp
}

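// getPChanStats reports the begin/end timestamp statistics for each physical
// channel this task writes to.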
func (it *upsertTask) getPChanStats() (map[pChan]pChanStatistics, error) {
    ret := make(map[pChan]pChanStatistics)

    channels := it.getChannels()

    beginTs := it.BeginTs()
    endTs := it.EndTs()

    for _, channel := range channels {
        ret[channel] = pChanStatistics{
            minTs: beginTs,
            maxTs: endTs,
        }
    }
    return ret, nil
}

func (it *upsertTask) setChannels() error {
    collID, err := globalMetaCache.GetCollectionID(it.ctx, it.req.GetDbName(), it.req.CollectionName)
    if err != nil {
        return err
    }
    channels, err := it.chMgr.getChannels(collID)
    if err != nil {
        return err
    }
    it.pChannels = channels
    return nil
}

func (it *upsertTask) getChannels() []pChan {
    return it.pChannels
}

func (it *upsertTask) OnEnqueue() error {
    if it.req.Base == nil {
        it.req.Base = commonpbutil.NewMsgBase()
    }
    it.req.Base.MsgType = commonpb.MsgType_Upsert
    it.req.Base.SourceID = paramtable.GetNodeID()
    return nil
}

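// retrieveByPKs queries the collection for the existing rows that match the
// given primary keys, requesting all fields ("*") at strong consistency so
// that the merge logic in queryPreExecute sees the latest data.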
func retrieveByPKs(ctx context.Context, t *upsertTask, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, segcore.StorageCost, error) {
    log := log.Ctx(ctx).With(zap.String("collectionName", t.req.GetCollectionName()))
    var err error
    queryReq := &milvuspb.QueryRequest{
        Base: &commonpb.MsgBase{
            MsgType:   commonpb.MsgType_Retrieve,
            Timestamp: t.BeginTs(),
        },
        DbName:                t.req.GetDbName(),
        CollectionName:        t.req.GetCollectionName(),
        ConsistencyLevel:      commonpb.ConsistencyLevel_Strong,
        NotReturnAllMeta:      false,
        OutputFields:          []string{"*"},
        UseDefaultConsistency: false,
        GuaranteeTimestamp:    t.BeginTs(),
        Namespace:             t.req.Namespace,
    }
    pkField, err := typeutil.GetPrimaryFieldSchema(t.schema.CollectionSchema)
    if err != nil {
        return nil, segcore.StorageCost{}, err
    }

    var partitionIDs []int64
    if t.partitionKeyMode {
        // multiple entities with the same pk but different partition keys may be
        // hashed to multiple physical partitions; if deleteMsg.partitionID ==
        // common.InvalidPartition, all segments with this pk under the collection
        // will carry the delete record
        partitionIDs = []int64{common.AllPartitionsID}
        queryReq.PartitionNames = []string{}
    } else {
        // the partition name is either defaultPartitionName or a name specified by the sdk
        partName := t.upsertMsg.DeleteMsg.PartitionName
        if err := validatePartitionTag(partName, true); err != nil {
            log.Warn("Invalid partition name", zap.String("partitionName", partName), zap.Error(err))
            return nil, segcore.StorageCost{}, err
        }
        partID, err := globalMetaCache.GetPartitionID(ctx, t.req.GetDbName(), t.req.GetCollectionName(), partName)
        if err != nil {
            log.Warn("Failed to get partition id", zap.String("partitionName", partName), zap.Error(err))
            return nil, segcore.StorageCost{}, err
        }
        partitionIDs = []int64{partID}
        queryReq.PartitionNames = []string{partName}
    }

    plan := planparserv2.CreateRequeryPlan(pkField, ids)
    plan.Namespace = t.req.Namespace
    qt := &queryTask{
        ctx:       t.ctx,
        Condition: NewTaskCondition(t.ctx),
        RetrieveRequest: &internalpb.RetrieveRequest{
            Base: commonpbutil.NewMsgBase(
                commonpbutil.WithMsgType(commonpb.MsgType_Retrieve),
                commonpbutil.WithSourceID(paramtable.GetNodeID()),
            ),
            ReqID:            paramtable.GetNodeID(),
            PartitionIDs:     partitionIDs,
            ConsistencyLevel: commonpb.ConsistencyLevel_Strong,
        },
        request:  queryReq,
        plan:     plan,
        mixCoord: t.node.(*Proxy).mixCoord,
        lb:       t.node.(*Proxy).lbPolicy,
    }

    ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-retrieveByPKs")
    defer func() {
        sp.End()
    }()
    queryResult, storageCost, err := t.node.(*Proxy).query(ctx, qt, sp)
    if err := merr.CheckRPCCall(queryResult.GetStatus(), err); err != nil {
        return nil, storageCost, err
    }
    return queryResult, storageCost, err
}

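// queryPreExecute implements the read phase of a partial-update upsert: it
// retrieves the existing rows by primary key, splits the incoming rows into
// update and insert groups, and merges old and new field data into
// it.insertFieldData and it.deletePKs for the write phase.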
func (it *upsertTask) queryPreExecute(ctx context.Context) error {
    log := log.Ctx(ctx).With(zap.String("collectionName", it.req.CollectionName))

    primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(it.schema.CollectionSchema)
    if err != nil {
        log.Warn("get primary field schema failed", zap.Error(err))
        return err
    }

    primaryFieldData, err := typeutil.GetPrimaryFieldData(it.req.GetFieldsData(), primaryFieldSchema)
    if err != nil {
        log.Error("get primary field data failed", zap.Error(err))
        return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("must assign pk when upsert, primary field: %v", primaryFieldSchema.Name))
    }

    upsertIDs, err := parsePrimaryFieldData2IDs(primaryFieldData)
    if err != nil {
        log.Warn("parse primary field data to IDs failed", zap.Error(err))
        return err
    }

    upsertIDSize := typeutil.GetSizeOfIDs(upsertIDs)
    if upsertIDSize == 0 {
        it.deletePKs = &schemapb.IDs{}
        it.insertFieldData = it.req.GetFieldsData()
        log.Info("old records not found, just do insert")
        return nil
    }

    tr := timerecord.NewTimeRecorder("Proxy-Upsert-retrieveByPKs")
    // retrieve by primary key to get the original field data
    resp, storageCost, err := retrieveByPKs(ctx, it, upsertIDs, []string{"*"})
    if err != nil {
        log.Info("retrieve by primary key failed", zap.Error(err))
        return err
    }
    it.storageCost = storageCost
    if len(resp.GetFieldsData()) == 0 {
        return merr.WrapErrParameterInvalidMsg("retrieve by primary key failed, no data found")
    }

    existFieldData := resp.GetFieldsData()
    pkFieldData, err := typeutil.GetPrimaryFieldData(existFieldData, primaryFieldSchema)
    if err != nil {
        log.Error("get primary field data failed", zap.Error(err))
        return err
    }
    existIDs, err := parsePrimaryFieldData2IDs(pkFieldData)
    if err != nil {
        log.Info("parse primary field data to ids failed", zap.Error(err))
        return err
    }
    log.Info("retrieveByPKs cost",
        zap.Int("resultNum", typeutil.GetSizeOfIDs(existIDs)),
        zap.Int64("latency", tr.ElapseSpan().Milliseconds()))

    // set the field id on user-passed field data to prepare for the merge logic
    if len(it.upsertMsg.InsertMsg.GetFieldsData()) == 0 {
        return merr.WrapErrParameterInvalidMsg("upsert field data is empty")
    }
    for _, fieldData := range it.upsertMsg.InsertMsg.GetFieldsData() {
        fieldName := fieldData.GetFieldName()
        if fieldData.GetIsDynamic() {
            fieldName = "$meta"
        }
        fieldSchema, err := it.schema.schemaHelper.GetFieldFromName(fieldName)
        if err != nil {
            log.Info("get field schema failed", zap.Error(err))
            return err
        }
        fieldData.FieldId = fieldSchema.GetFieldID()
        fieldData.FieldName = fieldName

        // stay compatible with the different nullable data formats sent by sdks
        if len(fieldData.GetValidData()) != 0 {
            err := FillWithNullValue(fieldData, fieldSchema, int(it.upsertMsg.InsertMsg.NRows()))
            if err != nil {
                log.Info("unify null field data format failed", zap.Error(err))
                return err
            }
        }
    }

    // Validate field data alignment before processing to prevent index-out-of-range panics
    if err := newValidateUtil().checkAligned(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema.schemaHelper, uint64(upsertIDSize)); err != nil {
        log.Warn("check field data aligned failed", zap.Error(err))
        return err
    }

    // Two nullable data formats are supported:
    //
    // COMPRESSED FORMAT (SDK format, before validateUtil.fillWithValue processing):
    //   Logical data: [1, null, 2]
    //   Storage: Data=[1, 2] + ValidData=[true, false, true]
    //   - Data array contains only non-null values (compressed)
    //   - ValidData array tracks null positions for all rows
    //
    // FULL FORMAT (Milvus internal format, after validateUtil.fillWithValue processing):
    //   Logical data: [1, null, 2]
    //   Storage: Data=[1, 0, 2] + ValidData=[true, false, true]
    //   - Data array contains values for all rows (nulls filled with zero/default)
    //   - ValidData array still tracks null positions
    //
    // Note: we unify the nullable format to FULL FORMAT before executing the merge logic
    insertIdxInUpsert := make([]int, 0)
    updateIdxInUpsert := make([]int, 0)
    // 1. split upsert data into insert and update groups based on the query result
    idsChecker, err := typeutil.NewIDsChecker(existIDs)
    if err != nil {
        log.Info("create primary key checker failed", zap.Error(err))
        return err
    }
    for upsertIdx := 0; upsertIdx < upsertIDSize; upsertIdx++ {
        exist, err := idsChecker.Contains(upsertIDs, upsertIdx)
        if err != nil {
            log.Info("check primary key exist in query result failed", zap.Error(err))
            return err
        }
        if exist {
            updateIdxInUpsert = append(updateIdxInUpsert, upsertIdx)
        } else {
            insertIdxInUpsert = append(insertIdxInUpsert, upsertIdx)
        }
    }

    // 2. merge field data with update semantics
    it.deletePKs = &schemapb.IDs{}
    it.insertFieldData = typeutil.PrepareResultFieldData(existFieldData, int64(upsertIDSize))

    if len(updateIdxInUpsert) > 0 {
        // Note: for fields with default values, the default is applied according to
        // valid data during insertion, but query results do not set valid data when
        // returning default-value fields, so valid data must be set to true manually
        for _, fieldData := range existFieldData {
            fieldSchema, err := it.schema.schemaHelper.GetFieldFromName(fieldData.GetFieldName())
            if err != nil {
                log.Info("get field schema failed", zap.Error(err))
                return err
            }

            if fieldSchema.GetDefaultValue() != nil {
                fieldData.ValidData = make([]bool, upsertIDSize)
                for i := range fieldData.ValidData {
                    fieldData.ValidData[i] = true
                }
            }
        }

        // Build a mapping from existing primary keys to their positions in the query
        // result. This locates the data correctly even if the query results are not
        // in the same order as the request.
        existIDsLen := typeutil.GetSizeOfIDs(existIDs)
        existPKToIndex := make(map[interface{}]int, existIDsLen)
        for j := 0; j < existIDsLen; j++ {
            pk := typeutil.GetPK(existIDs, int64(j))
            existPKToIndex[pk] = j
        }

        baseIdx := 0
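        // For nullable vector fields, the data array is stored in compressed form,
        // so a row's physical offset within each field can differ from its logical
        // row index. FieldDataIdxComputer translates the logical offset into one
        // physical offset per field, which AppendFieldData uses below.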
        idxComputer := typeutil.NewFieldDataIdxComputer(existFieldData)
        for _, idx := range updateIdxInUpsert {
            typeutil.AppendIDs(it.deletePKs, upsertIDs, idx)
            oldPK := typeutil.GetPK(upsertIDs, int64(idx))
            existIndex, ok := existPKToIndex[oldPK]
            if !ok {
                return merr.WrapErrParameterInvalidMsg("primary key not found in exist data mapping")
            }
            fieldIdxs := idxComputer.Compute(int64(existIndex))
            typeutil.AppendFieldData(it.insertFieldData, existFieldData, int64(existIndex), fieldIdxs...)
            err := typeutil.UpdateFieldData(it.insertFieldData, it.upsertMsg.InsertMsg.GetFieldsData(), int64(baseIdx), int64(idx))
            baseIdx += 1
            if err != nil {
                log.Info("update field data failed", zap.Error(err))
                return err
            }
        }
    }

    // 3. merge field data with insert semantics
    if len(insertIdxInUpsert) > 0 {
        // if a required field is missing from the upsert request, return an error
        lackOfFieldErr := LackOfFieldsDataBySchema(it.schema.CollectionSchema, it.upsertMsg.InsertMsg.GetFieldsData(), false, true)
        if lackOfFieldErr != nil {
            log.Info("check fields data by schema failed", zap.Error(lackOfFieldErr))
            return lackOfFieldErr
        }

        // if a nullable field was not passed in the upsert request, then
        // len(upsertFieldData) < len(it.insertFieldData); generate the nullable
        // field data before appending it as an insert
        insertWithNullField := make([]*schemapb.FieldData, 0)
        upsertFieldMap := lo.SliceToMap(it.upsertMsg.InsertMsg.GetFieldsData(), func(field *schemapb.FieldData) (string, *schemapb.FieldData) {
            return field.GetFieldName(), field
        })
        for _, fieldSchema := range it.schema.CollectionSchema.Fields {
            if fieldData, ok := upsertFieldMap[fieldSchema.Name]; !ok {
                if fieldSchema.GetNullable() || fieldSchema.GetDefaultValue() != nil {
                    fieldData, err := GenNullableFieldData(fieldSchema, upsertIDSize)
                    if err != nil {
                        log.Info("generate nullable field data failed", zap.Error(err))
                        return err
                    }
                    insertWithNullField = append(insertWithNullField, fieldData)
                }
            } else {
                insertWithNullField = append(insertWithNullField, fieldData)
            }
        }
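        // vectorIdxMap[rowIdx][fieldIdx] holds, for each inserted row, the physical
        // offset to read from each field; it defaults to the logical offset and is
        // overridden below for nullable vector fields stored in compressed form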
        vectorIdxMap := make([][]int64, len(insertIdxInUpsert))
        for rowIdx, offset := range insertIdxInUpsert {
            vectorIdxMap[rowIdx] = make([]int64, len(insertWithNullField))
            for fieldIdx := range insertWithNullField {
                vectorIdxMap[rowIdx][fieldIdx] = int64(offset)
            }
        }
        for fieldIdx, fieldData := range insertWithNullField {
            validData := fieldData.GetValidData()
            if len(validData) > 0 && typeutil.IsVectorType(fieldData.Type) {
                dataIdx := int64(0)
                rowIdx := 0
                for i := 0; i < len(validData) && rowIdx < len(insertIdxInUpsert); i++ {
                    if i == insertIdxInUpsert[rowIdx] {
                        vectorIdxMap[rowIdx][fieldIdx] = dataIdx
                        rowIdx++
                    }
                    if validData[i] {
                        dataIdx++
                    }
                }
            }
        }

        for rowIdx, idx := range insertIdxInUpsert {
            typeutil.AppendFieldData(it.insertFieldData, insertWithNullField, int64(idx), vectorIdxMap[rowIdx]...)
        }
    }

    for _, fieldData := range it.insertFieldData {
        if len(fieldData.GetValidData()) > 0 {
            err := ToCompressedFormatNullable(fieldData)
            if err != nil {
                log.Info("convert to compressed format nullable failed", zap.Error(err))
                return err
            }
        }
    }
    return nil
}

// ToCompressedFormatNullable converts field data from full-format nullable to
// compressed-format nullable.
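// For example, a full-format field with Data=[1, 0, 2] and
// ValidData=[true, false, true] becomes Data=[1, 2] with ValidData unchanged.
// If every row is valid, the field is already in compressed form and is
// returned as-is.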
func ToCompressedFormatNullable(field *schemapb.FieldData) error {
    if getValidNumber(field.GetValidData()) == len(field.GetValidData()) {
        return nil
    }
    switch field.Field.(type) {
    case *schemapb.FieldData_Scalars:
        switch sd := field.GetScalars().GetData().(type) {
        case *schemapb.ScalarField_BoolData:
            validRowNum := getValidNumber(field.GetValidData())
            if validRowNum == 0 {
                sd.BoolData.Data = make([]bool, 0)
            } else {
                ret := make([]bool, 0, validRowNum)
                for i, valid := range field.GetValidData() {
                    if valid {
                        ret = append(ret, sd.BoolData.Data[i])
                    }
                }
                sd.BoolData.Data = ret
            }

        case *schemapb.ScalarField_IntData:
            validRowNum := getValidNumber(field.GetValidData())
            if validRowNum == 0 {
                sd.IntData.Data = make([]int32, 0)
            } else {
                ret := make([]int32, 0, validRowNum)
                for i, valid := range field.GetValidData() {
                    if valid {
                        ret = append(ret, sd.IntData.Data[i])
                    }
                }
                sd.IntData.Data = ret
            }

        case *schemapb.ScalarField_LongData:
            validRowNum := getValidNumber(field.GetValidData())
            if validRowNum == 0 {
                sd.LongData.Data = make([]int64, 0)
            } else {
                ret := make([]int64, 0, validRowNum)
                for i, valid := range field.GetValidData() {
                    if valid {
                        ret = append(ret, sd.LongData.Data[i])
                    }
                }
                sd.LongData.Data = ret
            }

        case *schemapb.ScalarField_FloatData:
            validRowNum := getValidNumber(field.GetValidData())
            if validRowNum == 0 {
                sd.FloatData.Data = make([]float32, 0)
            } else {
                ret := make([]float32, 0, validRowNum)
                for i, valid := range field.GetValidData() {
                    if valid {
                        ret = append(ret, sd.FloatData.Data[i])
                    }
                }
                sd.FloatData.Data = ret
            }

        case *schemapb.ScalarField_DoubleData:
            validRowNum := getValidNumber(field.GetValidData())
            if validRowNum == 0 {
                sd.DoubleData.Data = make([]float64, 0)
            } else {
                ret := make([]float64, 0, validRowNum)
                for i, valid := range field.GetValidData() {
                    if valid {
                        ret = append(ret, sd.DoubleData.Data[i])
                    }
                }
                sd.DoubleData.Data = ret
            }

        case *schemapb.ScalarField_StringData:
            validRowNum := getValidNumber(field.GetValidData())
            if validRowNum == 0 {
                sd.StringData.Data = make([]string, 0)
            } else {
                ret := make([]string, 0, validRowNum)
                for i, valid := range field.GetValidData() {
                    if valid {
                        ret = append(ret, sd.StringData.Data[i])
                    }
                }
                sd.StringData.Data = ret
            }

        case *schemapb.ScalarField_JsonData:
            validRowNum := getValidNumber(field.GetValidData())
            if validRowNum == 0 {
                sd.JsonData.Data = make([][]byte, 0)
            } else {
                ret := make([][]byte, 0, validRowNum)
                for i, valid := range field.GetValidData() {
                    if valid {
                        ret = append(ret, sd.JsonData.Data[i])
                    }
                }
                sd.JsonData.Data = ret
            }

        case *schemapb.ScalarField_ArrayData:
            validRowNum := getValidNumber(field.GetValidData())
            if validRowNum == 0 {
                sd.ArrayData.Data = make([]*schemapb.ScalarField, 0)
            } else {
                ret := make([]*schemapb.ScalarField, 0, validRowNum)
                for i, valid := range field.GetValidData() {
                    if valid {
                        ret = append(ret, sd.ArrayData.Data[i])
                    }
                }
                sd.ArrayData.Data = ret
            }

        case *schemapb.ScalarField_TimestamptzData:
            validRowNum := getValidNumber(field.GetValidData())
            if validRowNum == 0 {
                sd.TimestamptzData.Data = make([]int64, 0)
            } else {
                ret := make([]int64, 0, validRowNum)
                for i, valid := range field.GetValidData() {
                    if valid {
                        ret = append(ret, sd.TimestamptzData.Data[i])
                    }
                }
                sd.TimestamptzData.Data = ret
            }

        case *schemapb.ScalarField_GeometryWktData:
            validRowNum := getValidNumber(field.GetValidData())
            if validRowNum == 0 {
                sd.GeometryWktData.Data = make([]string, 0)
            } else {
                ret := make([]string, 0, validRowNum)
                for i, valid := range field.GetValidData() {
                    if valid {
                        ret = append(ret, sd.GeometryWktData.Data[i])
                    }
                }
                sd.GeometryWktData.Data = ret
            }

        case *schemapb.ScalarField_GeometryData:
            validRowNum := getValidNumber(field.GetValidData())
            if validRowNum == 0 {
                sd.GeometryData.Data = make([][]byte, 0)
            } else {
                ret := make([][]byte, 0, validRowNum)
                for i, valid := range field.GetValidData() {
                    if valid {
                        ret = append(ret, sd.GeometryData.Data[i])
                    }
                }
                sd.GeometryData.Data = ret
            }

        default:
            return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("undefined data type:%s", field.Type.String()))
        }

    case *schemapb.FieldData_Vectors:
        // Vector data is already in compressed format, skip
        return nil

    default:
        return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("undefined data type:%s", field.Type.String()))
    }

    return nil
}

// GenNullableFieldData generates nullable field data in FULL FORMAT.
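// Every row is marked null (ValidData defaults to all false) and the data
// array holds upsertIDSize zero values, matching the FULL FORMAT described in
// queryPreExecute.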
func GenNullableFieldData(field *schemapb.FieldSchema, upsertIDSize int) (*schemapb.FieldData, error) {
    switch field.DataType {
    case schemapb.DataType_Bool:
        return &schemapb.FieldData{
            FieldId:   field.FieldID,
            FieldName: field.Name,
            Type:      field.DataType,
            IsDynamic: field.IsDynamic,
            ValidData: make([]bool, upsertIDSize),
            Field: &schemapb.FieldData_Scalars{
                Scalars: &schemapb.ScalarField{
                    Data: &schemapb.ScalarField_BoolData{
                        BoolData: &schemapb.BoolArray{
                            Data: make([]bool, upsertIDSize),
                        },
                    },
                },
            },
        }, nil

    case schemapb.DataType_Int32:
        return &schemapb.FieldData{
            FieldId:   field.FieldID,
            FieldName: field.Name,
            Type:      field.DataType,
            IsDynamic: field.IsDynamic,
            ValidData: make([]bool, upsertIDSize),
            Field: &schemapb.FieldData_Scalars{
                Scalars: &schemapb.ScalarField{
                    Data: &schemapb.ScalarField_IntData{
                        IntData: &schemapb.IntArray{
                            Data: make([]int32, upsertIDSize),
                        },
                    },
                },
            },
        }, nil

    case schemapb.DataType_Int64:
        return &schemapb.FieldData{
            FieldId:   field.FieldID,
            FieldName: field.Name,
            Type:      field.DataType,
            IsDynamic: field.IsDynamic,
            ValidData: make([]bool, upsertIDSize),
            Field: &schemapb.FieldData_Scalars{
                Scalars: &schemapb.ScalarField{
                    Data: &schemapb.ScalarField_LongData{
                        LongData: &schemapb.LongArray{
                            Data: make([]int64, upsertIDSize),
                        },
                    },
                },
            },
        }, nil

    case schemapb.DataType_Float:
        return &schemapb.FieldData{
            FieldId:   field.FieldID,
            FieldName: field.Name,
            Type:      field.DataType,
            IsDynamic: field.IsDynamic,
            ValidData: make([]bool, upsertIDSize),
            Field: &schemapb.FieldData_Scalars{
                Scalars: &schemapb.ScalarField{
                    Data: &schemapb.ScalarField_FloatData{
                        FloatData: &schemapb.FloatArray{
                            Data: make([]float32, upsertIDSize),
                        },
                    },
                },
            },
        }, nil

    case schemapb.DataType_Double:
        return &schemapb.FieldData{
            FieldId:   field.FieldID,
            FieldName: field.Name,
            Type:      field.DataType,
            IsDynamic: field.IsDynamic,
            ValidData: make([]bool, upsertIDSize),
            Field: &schemapb.FieldData_Scalars{
                Scalars: &schemapb.ScalarField{
                    Data: &schemapb.ScalarField_DoubleData{
                        DoubleData: &schemapb.DoubleArray{
                            Data: make([]float64, upsertIDSize),
                        },
                    },
                },
            },
        }, nil

    case schemapb.DataType_VarChar:
        return &schemapb.FieldData{
            FieldId:   field.FieldID,
            FieldName: field.Name,
            Type:      field.DataType,
            IsDynamic: field.IsDynamic,
            ValidData: make([]bool, upsertIDSize),
            Field: &schemapb.FieldData_Scalars{
                Scalars: &schemapb.ScalarField{
                    Data: &schemapb.ScalarField_StringData{
                        StringData: &schemapb.StringArray{
                            Data: make([]string, upsertIDSize),
                        },
                    },
                },
            },
        }, nil

    case schemapb.DataType_JSON:
        return &schemapb.FieldData{
            FieldId:   field.FieldID,
            FieldName: field.Name,
            Type:      field.DataType,
            IsDynamic: field.IsDynamic,
            ValidData: make([]bool, upsertIDSize),
            Field: &schemapb.FieldData_Scalars{
                Scalars: &schemapb.ScalarField{
                    Data: &schemapb.ScalarField_JsonData{
                        JsonData: &schemapb.JSONArray{
                            Data: make([][]byte, upsertIDSize),
                        },
                    },
                },
            },
        }, nil

    case schemapb.DataType_Array:
        return &schemapb.FieldData{
            FieldId:   field.FieldID,
            FieldName: field.Name,
            Type:      field.DataType,
            IsDynamic: field.IsDynamic,
            ValidData: make([]bool, upsertIDSize),
            Field: &schemapb.FieldData_Scalars{
                Scalars: &schemapb.ScalarField{
                    Data: &schemapb.ScalarField_ArrayData{
                        ArrayData: &schemapb.ArrayArray{
                            Data: make([]*schemapb.ScalarField, upsertIDSize),
                        },
                    },
                },
            },
        }, nil

    case schemapb.DataType_Timestamptz:
        return &schemapb.FieldData{
            FieldId:   field.FieldID,
            FieldName: field.Name,
            Type:      field.DataType,
            IsDynamic: field.IsDynamic,
            ValidData: make([]bool, upsertIDSize),
            Field: &schemapb.FieldData_Scalars{
                Scalars: &schemapb.ScalarField{
                    Data: &schemapb.ScalarField_TimestamptzData{
                        TimestamptzData: &schemapb.TimestamptzArray{
                            Data: make([]int64, upsertIDSize),
                        },
                    },
                },
            },
        }, nil

    // the input data of a geometry field is in wkt format
    case schemapb.DataType_Geometry:
        return &schemapb.FieldData{
            FieldId:   field.FieldID,
            FieldName: field.Name,
            Type:      field.DataType,
            IsDynamic: field.IsDynamic,
            ValidData: make([]bool, upsertIDSize),
            Field: &schemapb.FieldData_Scalars{
                Scalars: &schemapb.ScalarField{
                    Data: &schemapb.ScalarField_GeometryWktData{
                        GeometryWktData: &schemapb.GeometryWktArray{
                            Data: make([]string, upsertIDSize),
                        },
                    },
                },
            },
        }, nil
    default:
        return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("undefined scalar data type:%s", field.DataType.String()))
    }
}

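// insertPreExecute prepares the insert half of the upsert: it allocates row IDs
// and timestamps, validates dynamic/struct/primary-key field data, resolves
// partition info, and runs NaN/overflow/max-length validation on the field data.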
func (it *upsertTask) insertPreExecute(ctx context.Context) error {
    ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-insertPreExecute")
    defer sp.End()
    collectionName := it.upsertMsg.InsertMsg.CollectionName
    if err := validateCollectionName(collectionName); err != nil {
        log.Ctx(ctx).Error("valid collection name failed", zap.String("collectionName", collectionName), zap.Error(err))
        return err
    }

    bm25Fields := typeutil.NewSet[string](GetBM25FunctionOutputFields(it.schema.CollectionSchema)...)
    if it.req.PartialUpdate {
        // remove the old bm25 fields
        ret := make([]*schemapb.FieldData, 0)
        for _, fieldData := range it.upsertMsg.InsertMsg.GetFieldsData() {
            if bm25Fields.Contain(fieldData.GetFieldName()) {
                continue
            }
            ret = append(ret, fieldData)
        }
        it.upsertMsg.InsertMsg.FieldsData = ret
    }

    rowNums := uint32(it.upsertMsg.InsertMsg.NRows())
    // set upsertTask.insertRequest.rowIDs
    tr := timerecord.NewTimeRecorder("applyPK")
    clusterID := Params.CommonCfg.ClusterID.GetAsUint64()
    rowIDBegin, rowIDEnd, _ := common.AllocAutoID(it.idAllocator.Alloc, rowNums, clusterID)
    metrics.ProxyApplyPrimaryKeyLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(tr.ElapseSpan().Milliseconds()))

    it.upsertMsg.InsertMsg.RowIDs = make([]UniqueID, rowNums)
    it.rowIDs = make([]UniqueID, rowNums)
    for i := rowIDBegin; i < rowIDEnd; i++ {
        offset := i - rowIDBegin
        it.upsertMsg.InsertMsg.RowIDs[offset] = i
        it.rowIDs[offset] = i
    }
    // set upsertTask.insertRequest.timeStamps
    rowNum := it.upsertMsg.InsertMsg.NRows()
    it.upsertMsg.InsertMsg.Timestamps = make([]uint64, rowNum)
    it.timestamps = make([]uint64, rowNum)
    for index := range it.timestamps {
        it.upsertMsg.InsertMsg.Timestamps[index] = it.BeginTs()
        it.timestamps[index] = it.BeginTs()
    }
    // set result.SuccIndex
    sliceIndex := make([]uint32, rowNums)
    for i := uint32(0); i < rowNums; i++ {
        sliceIndex[i] = i
    }
    it.result.SuccIndex = sliceIndex

    if it.schema.EnableDynamicField {
        err := checkDynamicFieldData(it.schema.CollectionSchema, it.upsertMsg.InsertMsg)
        if err != nil {
            return err
        }
    }

    if Params.CommonCfg.EnableNamespace.GetAsBool() {
        err := addNamespaceData(it.schema.CollectionSchema, it.upsertMsg.InsertMsg)
        if err != nil {
            return err
        }
    }

    if err := checkAndFlattenStructFieldData(it.schema.CollectionSchema, it.upsertMsg.InsertMsg); err != nil {
        return err
    }

    allFields := typeutil.GetAllFieldSchemas(it.schema.CollectionSchema)

    // use the passed pk as the new pk when autoID == false;
    // automatically generate a new pk when autoID == true
    var err error
    it.result.IDs, it.oldIDs, err = checkUpsertPrimaryFieldData(allFields, it.schema.CollectionSchema, it.upsertMsg.InsertMsg)
    log := log.Ctx(ctx).With(zap.String("collectionName", it.upsertMsg.InsertMsg.CollectionName))
    if err != nil {
        log.Warn("check primary field data and hash primary key failed when upsert",
            zap.Error(err))
        return merr.WrapErrAsInputErrorWhen(err, merr.ErrParameterInvalid)
    }

    // check that varchar/text fields with an analyzer are valid utf-8
    err = checkInputUtf8Compatiable(allFields, it.upsertMsg.InsertMsg)
    if err != nil {
        log.Warn("check varchar/text format failed", zap.Error(err))
        return err
    }

    // Validate and set field IDs on the insert field data
    err = validateFieldDataColumns(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema)
    if err != nil {
        log.Warn("validate field data columns failed when upsert", zap.Error(err))
        return merr.WrapErrAsInputErrorWhen(err, merr.ErrParameterInvalid)
    }
    err = fillFieldPropertiesOnly(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema)
    if err != nil {
        log.Warn("fill field properties failed when upsert", zap.Error(err))
        return merr.WrapErrAsInputErrorWhen(err, merr.ErrParameterInvalid)
    }

    if it.partitionKeyMode {
        fieldSchema, _ := typeutil.GetPartitionKeyFieldSchema(it.schema.CollectionSchema)
        it.partitionKeys, err = getPartitionKeyFieldData(fieldSchema, it.upsertMsg.InsertMsg)
        if err != nil {
            log.Warn("get partition keys from insert request failed",
                zap.String("collectionName", collectionName),
                zap.Error(err))
            return err
        }
    } else {
        partitionTag := it.upsertMsg.InsertMsg.PartitionName
        if err = validatePartitionTag(partitionTag, true); err != nil {
            log.Warn("valid partition name failed", zap.String("partition name", partitionTag), zap.Error(err))
            return err
        }
    }

    if err := newValidateUtil(withNANCheck(), withOverflowCheck(), withMaxLenCheck()).
        Validate(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema.schemaHelper, it.upsertMsg.InsertMsg.NRows()); err != nil {
        return err
    }

    log.Debug("Proxy Upsert insertPreExecute done")

    return nil
}

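// deletePreExecute prepares the delete half of the upsert: it resolves the
// target partition and assigns a timestamp to every primary key scheduled for
// deletion. Without partial update, the primary keys default to it.oldIDs.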
func (it *upsertTask) deletePreExecute(ctx context.Context) error {
    collName := it.upsertMsg.DeleteMsg.CollectionName
    log := log.Ctx(ctx).With(
        zap.String("collectionName", collName))

    if it.upsertMsg.DeleteMsg.PrimaryKeys == nil {
        // if primary keys were not set by queryPreExecute, use oldIDs to delete all given records
        it.upsertMsg.DeleteMsg.PrimaryKeys = it.oldIDs
    }
    if typeutil.GetSizeOfIDs(it.upsertMsg.DeleteMsg.PrimaryKeys) == 0 {
        log.Info("deletePKs is empty, skip deleteExecute")
        return nil
    }

    if err := validateCollectionName(collName); err != nil {
        log.Info("Invalid collectionName", zap.Error(err))
        return err
    }

    if it.partitionKeyMode {
        // multiple entities with the same pk but different partition keys may be
        // hashed to multiple physical partitions; if deleteMsg.partitionID ==
        // common.InvalidPartition, all segments with this pk under the collection
        // will carry the delete record
        it.upsertMsg.DeleteMsg.PartitionID = common.AllPartitionsID
    } else {
        // the partition name is either defaultPartitionName or a name specified by the sdk
        partName := it.upsertMsg.DeleteMsg.PartitionName
        if err := validatePartitionTag(partName, true); err != nil {
            log.Warn("Invalid partition name", zap.String("partitionName", partName), zap.Error(err))
            return err
        }
        partID, err := globalMetaCache.GetPartitionID(ctx, it.req.GetDbName(), collName, partName)
        if err != nil {
            log.Warn("Failed to get partition id", zap.String("collectionName", collName), zap.String("partitionName", partName), zap.Error(err))
            return err
        }
        it.upsertMsg.DeleteMsg.PartitionID = partID
    }

    it.upsertMsg.DeleteMsg.Timestamps = make([]uint64, it.upsertMsg.DeleteMsg.NumRows)
    for index := range it.upsertMsg.DeleteMsg.Timestamps {
        it.upsertMsg.DeleteMsg.Timestamps[index] = it.BeginTs()
    }
    log.Debug("Proxy Upsert deletePreExecute done")
    return nil
}

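// PreExecute validates the upsert request (collection, schema timestamp,
// partition mode, duplicate primary keys), builds the paired insert/delete
// messages, and runs queryPreExecute (for partial update), insertPreExecute,
// and deletePreExecute in turn.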
func (it *upsertTask) PreExecute(ctx context.Context) error {
    ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-PreExecute")
    defer sp.End()

    collectionName := it.req.CollectionName
    log := log.Ctx(ctx).With(zap.String("collectionName", collectionName))

    it.result = &milvuspb.MutationResult{
        Status: merr.Success(),
        IDs: &schemapb.IDs{
            IdField: nil,
        },
        Timestamp: it.EndTs(),
    }

    replicateID, err := GetReplicateID(ctx, it.req.GetDbName(), collectionName)
    if err != nil {
        log.Warn("get replicate info failed", zap.String("collectionName", collectionName), zap.Error(err))
        return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound)
    }
    if replicateID != "" {
        return merr.WrapErrCollectionReplicateMode("upsert")
    }

    // check that the collection exists
    collID, err := globalMetaCache.GetCollectionID(context.Background(), it.req.GetDbName(), collectionName)
    if err != nil {
        log.Warn("fail to get collection id", zap.Error(err))
        return err
    }
    it.collectionID = collID

    colInfo, err := globalMetaCache.GetCollectionInfo(ctx, it.req.GetDbName(), collectionName, collID)
    if err != nil {
        log.Warn("fail to get collection info", zap.Error(err))
        return err
    }
    if it.schemaTimestamp != 0 {
        if it.schemaTimestamp != colInfo.updateTimestamp {
            err := merr.WrapErrCollectionSchemaMisMatch(collectionName)
            log.Info("collection schema mismatch", zap.String("collectionName", collectionName),
                zap.Uint64("requestSchemaTs", it.schemaTimestamp),
                zap.Uint64("collectionSchemaTs", colInfo.updateTimestamp),
                zap.Error(err))
            return err
        }
    }

    schema, err := globalMetaCache.GetCollectionSchema(ctx, it.req.GetDbName(), collectionName)
    if err != nil {
        log.Warn("Failed to get collection schema",
            zap.String("collectionName", collectionName),
            zap.Error(err))
        return err
    }
    it.schema = schema

    err = common.CheckNamespace(schema.CollectionSchema, it.req.Namespace)
    if err != nil {
        return err
    }

    it.partitionKeyMode, err = isPartitionKeyMode(ctx, it.req.GetDbName(), collectionName)
    if err != nil {
        log.Warn("check partition key mode failed",
            zap.String("collectionName", collectionName),
            zap.Error(err))
        return err
    }
    if it.partitionKeyMode {
        if len(it.req.GetPartitionName()) > 0 {
            return errors.New("not support manually specifying the partition names if partition key mode is used")
        }
    } else {
        // set the default partition name if the partition key is not used,
        // i.e. insert into the _default partition
        partitionTag := it.req.GetPartitionName()
        if len(partitionTag) <= 0 {
            pinfo, err := globalMetaCache.GetPartitionInfo(ctx, it.req.GetDbName(), collectionName, "")
            if err != nil {
                log.Warn("get partition info failed", zap.String("collectionName", collectionName), zap.Error(err))
                return err
            }
            it.req.PartitionName = pinfo.name
        }
    }

    // check for duplicate primary keys in the same batch
    primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema.CollectionSchema)
    if err != nil {
        log.Warn("fail to get primary field schema", zap.Error(err))
        return err
    }
    duplicate, err := CheckDuplicatePkExist(primaryFieldSchema, it.req.GetFieldsData())
    if err != nil {
        log.Warn("fail to check duplicate primary keys", zap.Error(err))
        return err
    }
    if duplicate {
        return merr.WrapErrParameterInvalidMsg("duplicate primary keys are not allowed in the same batch")
    }

    it.upsertMsg = &msgstream.UpsertMsg{
        InsertMsg: &msgstream.InsertMsg{
            InsertRequest: &msgpb.InsertRequest{
                Base: commonpbutil.NewMsgBase(
                    commonpbutil.WithMsgType(commonpb.MsgType_Insert),
                    commonpbutil.WithSourceID(paramtable.GetNodeID()),
                ),
                CollectionName: it.req.CollectionName,
                CollectionID:   it.collectionID,
                PartitionName:  it.req.PartitionName,
                FieldsData:     it.req.FieldsData,
                NumRows:        uint64(it.req.NumRows),
                Version:        msgpb.InsertDataVersion_ColumnBased,
                DbName:         it.req.DbName,
                Namespace:      it.req.Namespace,
            },
        },
        DeleteMsg: &msgstream.DeleteMsg{
            DeleteRequest: &msgpb.DeleteRequest{
                Base: commonpbutil.NewMsgBase(
                    commonpbutil.WithMsgType(commonpb.MsgType_Delete),
                    commonpbutil.WithSourceID(paramtable.GetNodeID()),
                ),
                DbName:         it.req.DbName,
                CollectionName: it.req.CollectionName,
                CollectionID:   it.collectionID,
                NumRows:        int64(it.req.NumRows),
                PartitionName:  it.req.PartitionName,
            },
        },
    }

    // check that num_rows is valid
    if it.req.NumRows <= 0 {
        return merr.WrapErrParameterInvalid("invalid num_rows", fmt.Sprint(it.req.NumRows), "num_rows should be greater than 0")
    }

    if err := genFunctionFields(ctx, it.upsertMsg.InsertMsg, it.schema, it.req.GetPartialUpdate()); err != nil {
        return err
    }

    if it.req.GetPartialUpdate() {
        err = it.queryPreExecute(ctx)
        if err != nil {
            log.Warn("Fail to queryPreExecute", zap.Error(err))
            return err
        }
        // reconstruct the upsert msg after queryPreExecute
        it.upsertMsg.InsertMsg.FieldsData = it.insertFieldData
        it.upsertMsg.DeleteMsg.PrimaryKeys = it.deletePKs
        it.upsertMsg.DeleteMsg.NumRows = int64(typeutil.GetSizeOfIDs(it.deletePKs))
    }

    err = it.insertPreExecute(ctx)
    if err != nil {
        log.Warn("Fail to insertPreExecute", zap.Error(err))
        return err
    }

    err = it.deletePreExecute(ctx)
    if err != nil {
        log.Warn("Fail to deletePreExecute", zap.Error(err))
        return err
    }

    it.result.DeleteCnt = it.upsertMsg.DeleteMsg.NumRows
    it.result.InsertCnt = int64(it.upsertMsg.InsertMsg.NumRows)
    if it.result.DeleteCnt != it.result.InsertCnt {
        log.Info("DeleteCnt and InsertCnt are not the same when upsert",
            zap.Int64("DeleteCnt", it.result.DeleteCnt),
            zap.Int64("InsertCnt", it.result.InsertCnt))
    }
    it.result.UpsertCnt = it.result.InsertCnt
    log.Debug("Proxy Upsert PreExecute done")
    return nil
}

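// PostExecute is a no-op for upsert tasks.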
func (it *upsertTask) PostExecute(ctx context.Context) error {
    return nil
}