diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index 0d714f6b5c..53912a39ca 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -1377,13 +1377,17 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest) ctx: ctx, Condition: NewTaskCondition(ctx), req: deleteReq, - DeleteRequest: &internalpb.DeleteRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Delete, - SourceID: Params.ProxyID, + BaseDeleteTask: BaseDeleteTask{ + BaseMsg: msgstream.BaseMsg{}, + DeleteRequest: internalpb.DeleteRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Delete, + MsgID: 0, + }, + CollectionName: request.CollectionName, + PartitionName: request.PartitionName, + // RowData: transfer column based request to this }, - CollectionName: request.CollectionName, - PartitionName: request.PartitionName, }, chMgr: node.chMgr, chTicker: node.chTicker, diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 63ee334361..874b431333 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -583,7 +583,7 @@ func (it *insertTask) transferColumnBasedRequestToRowBasedData() error { return nil } -func (it *insertTask) checkFieldAutoID() error { +func (it *insertTask) checkFieldAutoIDAndHashPK() error { // TODO(dragondriver): in fact, NumRows is not trustable, we should check all input fields if it.req.NumRows <= 0 { return errNumRowsLessThanOrEqualToZero(it.req.NumRows) @@ -694,30 +694,16 @@ func (it *insertTask) checkFieldAutoID() error { }, } - // TODO(dragondriver): in this case, should we directly overwrite the hash? - - if len(it.HashValues) != 0 && len(it.HashValues) != len(it.BaseInsertTask.RowIDs) { - return fmt.Errorf("invalid length of input hash values") - } - if it.HashValues == nil || len(it.HashValues) <= 0 { - it.HashValues = make([]uint32, 0, len(it.BaseInsertTask.RowIDs)) - for _, rowID := range it.BaseInsertTask.RowIDs { - hash, _ := typeutil.Hash32Int64(rowID) - it.HashValues = append(it.HashValues, hash) - } + it.HashValues = make([]uint32, 0, len(it.BaseInsertTask.RowIDs)) + for _, rowID := range it.BaseInsertTask.RowIDs { + hash, _ := typeutil.Hash32Int64(rowID) + it.HashValues = append(it.HashValues, hash) } } else { - // use primary keys as hash if hash is not provided - // in this case, primary field is required, we have already checked this - if uint32(len(it.HashValues)) != 0 && uint32(len(it.HashValues)) != rowNums { - return fmt.Errorf("invalid length of input hash values") - } - if it.HashValues == nil || len(it.HashValues) <= 0 { - it.HashValues = make([]uint32, 0, len(primaryData)) - for _, pk := range primaryData { - hash, _ := typeutil.Hash32Int64(pk) - it.HashValues = append(it.HashValues, hash) - } + it.HashValues = make([]uint32, 0, len(primaryData)) + for _, pk := range primaryData { + hash, _ := typeutil.Hash32Int64(pk) + it.HashValues = append(it.HashValues, hash) } } @@ -768,7 +754,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error { return err } - err = it.checkFieldAutoID() + err = it.checkFieldAutoIDAndHashPK() if err != nil { return err } @@ -824,7 +810,6 @@ func (it *insertTask) _assignSegmentID(stream msgstream.MsgStream, pack *msgstre if keysLen != timestampLen || keysLen != rowIDLen || keysLen != rowDataLen { return nil, fmt.Errorf("the length of hashValue, timestamps, rowIDs, RowData are not equal") } - for idx, channelID := range keys { channelCountMap[channelID]++ if _, ok := channelMaxTSMap[channelID]; !ok { @@ -4645,9 +4630,11 @@ func (rpt *releasePartitionsTask) PostExecute(ctx context.Context) error { return nil } +type BaseDeleteTask = msgstream.DeleteMsg + type deleteTask struct { Condition - *internalpb.DeleteRequest + BaseDeleteTask ctx context.Context req *milvuspb.DeleteRequest result *milvuspb.MutationResult @@ -4772,7 +4759,7 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error { log.Error("Failed to get primary keys from expr", zap.Error(err)) return err } - log.Debug("get primary keys from expr", zap.Any("primary keys", dt.DeleteRequest.PrimaryKeys)) + log.Debug("get primary keys from expr", zap.Any("primary keys", primaryKeys)) dt.DeleteRequest.PrimaryKeys = primaryKeys // set result @@ -4783,6 +4770,8 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error { } dt.result.DeleteCnt = int64(len(primaryKeys)) + dt.HashPK(primaryKeys) + rowNum := len(primaryKeys) dt.Timestamps = make([]uint64, rowNum) for index := range dt.Timestamps { @@ -4796,15 +4785,7 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) { sp, ctx := trace.StartSpanFromContextWithOperationName(dt.ctx, "Proxy-Delete-Execute") defer sp.Finish() - var tsMsg msgstream.TsMsg = &msgstream.DeleteMsg{ - DeleteRequest: *dt.DeleteRequest, - BaseMsg: msgstream.BaseMsg{ - Ctx: ctx, - HashValues: []uint32{uint32(Params.ProxyID)}, - BeginTimestamp: dt.BeginTs(), - EndTimestamp: dt.EndTs(), - }, - } + var tsMsg msgstream.TsMsg = &dt.BaseDeleteTask msgPack := msgstream.MsgPack{ BeginTs: dt.BeginTs(), EndTs: dt.EndTs(), @@ -4839,8 +4820,64 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) { return err } } + result := make(map[int32]msgstream.TsMsg) + hashKeys := stream.ComputeProduceChannelIndexes(msgPack.Msgs) + // For each msg, assign PK to different message buckets by hash value of PK. + for i, request := range msgPack.Msgs { + deleteRequest := request.(*msgstream.DeleteMsg) + keys := hashKeys[i] + collectionName := deleteRequest.CollectionName + collectionID := deleteRequest.CollectionID + partitionID := deleteRequest.PartitionID + partitionName := deleteRequest.PartitionName + proxyID := deleteRequest.Base.SourceID + for index, key := range keys { + ts := deleteRequest.Timestamps[index] + pks := deleteRequest.PrimaryKeys[index] + _, ok := result[key] + if !ok { + sliceRequest := internalpb.DeleteRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Delete, + MsgID: dt.Base.MsgID, + Timestamp: ts, + SourceID: proxyID, + }, + CollectionID: collectionID, + PartitionID: partitionID, + CollectionName: collectionName, + PartitionName: partitionName, + } + deleteMsg := &msgstream.DeleteMsg{ + BaseMsg: msgstream.BaseMsg{ + Ctx: ctx, + }, + DeleteRequest: sliceRequest, + } + result[key] = deleteMsg + } + curMsg := result[key].(*msgstream.DeleteMsg) + curMsg.HashValues = append(curMsg.HashValues, deleteRequest.HashValues[index]) + curMsg.Timestamps = append(curMsg.Timestamps, ts) + curMsg.PrimaryKeys = append(curMsg.PrimaryKeys, pks) + } + } - err = stream.Broadcast(&msgPack) + newPack := &msgstream.MsgPack{ + BeginTs: msgPack.BeginTs, + EndTs: msgPack.EndTs, + StartPositions: msgPack.StartPositions, + EndPositions: msgPack.EndPositions, + Msgs: make([]msgstream.TsMsg, 0), + } + + for _, msg := range result { + if msg != nil { + newPack.Msgs = append(newPack.Msgs, msg) + } + } + + err = stream.Produce(newPack) if err != nil { dt.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError dt.result.Status.Reason = err.Error() @@ -4853,6 +4890,14 @@ func (dt *deleteTask) PostExecute(ctx context.Context) error { return nil } +func (dt *deleteTask) HashPK(pks []int64) { + dt.HashValues = make([]uint32, 0, len(pks)) + for _, pk := range pks { + hash, _ := typeutil.Hash32Int64(pk) + dt.HashValues = append(dt.HashValues, hash) + } +} + type CreateAliasTask struct { Condition *milvuspb.CreateAliasRequest diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index 3214a1857f..5cfe08f803 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -3345,15 +3345,18 @@ func TestTask_all(t *testing.T) { t.Run("delete", func(t *testing.T) { task := &deleteTask{ Condition: NewTaskCondition(ctx), - DeleteRequest: &internalpb.DeleteRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Delete, - MsgID: 0, - Timestamp: 0, - SourceID: Params.ProxyID, + BaseDeleteTask: msgstream.DeleteMsg{ + BaseMsg: msgstream.BaseMsg{}, + DeleteRequest: internalpb.DeleteRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Delete, + MsgID: 0, + Timestamp: 0, + SourceID: Params.ProxyID, + }, + CollectionName: collectionName, + PartitionName: partitionName, }, - CollectionName: collectionName, - PartitionName: partitionName, }, req: &milvuspb.DeleteRequest{ Base: &commonpb.MsgBase{