diff --git a/internal/proxy/task_delete_streaming.go b/internal/proxy/task_delete_streaming.go index deec98d9b8..88a29ae811 100644 --- a/internal/proxy/task_delete_streaming.go +++ b/internal/proxy/task_delete_streaming.go @@ -8,6 +8,7 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus/internal/distributed/streaming" + "github.com/milvus-io/milvus/internal/util/hookutil" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" "github.com/milvus-io/milvus/pkg/v2/util/merr" @@ -38,6 +39,17 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) { return err } + var ez *message.CipherConfig + if hookutil.IsClusterEncyptionEnabled() { + schema, err := globalMetaCache.GetCollectionSchema(ctx, dt.req.GetDbName(), dt.req.GetCollectionName()) + if err != nil { + log.Ctx(ctx).Warn("get collection schema from global meta cache failed", zap.String("collectionName", dt.req.GetCollectionName()), zap.Error(err)) + return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound) + } + + ez = hookutil.GetEzByCollProperties(schema.GetProperties(), dt.collectionID).AsMessageConfig() + } + var msgs []message.MutableMessage for hashKey, deleteMsgs := range result { vchannel := dt.vChannels[hashKey] @@ -49,6 +61,7 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) { }). WithBody(deleteMsg.DeleteRequest). WithVChannel(vchannel). + WithCipher(ez). BuildMutable() if err != nil { return err diff --git a/internal/proxy/task_insert.go b/internal/proxy/task_insert.go index 63c365db6e..910724f4b1 100644 --- a/internal/proxy/task_insert.go +++ b/internal/proxy/task_insert.go @@ -36,6 +36,7 @@ type insertTask struct { schema *schemapb.CollectionSchema partitionKeys *schemapb.FieldData schemaTimestamp uint64 + collectionID int64 } // TraceCtx returns insertTask context @@ -137,6 +138,8 @@ func (it *insertTask) PreExecute(ctx context.Context) error { log.Ctx(ctx).Warn("fail to get collection id", zap.Error(err)) return err } + it.collectionID = collID + colInfo, err := globalMetaCache.GetCollectionInfo(ctx, it.insertMsg.GetDbName(), collectionName, collID) if err != nil { log.Ctx(ctx).Warn("fail to get collection info", zap.Error(err)) diff --git a/internal/proxy/task_insert_streaming.go b/internal/proxy/task_insert_streaming.go index 90d79fa312..2cd08b8909 100644 --- a/internal/proxy/task_insert_streaming.go +++ b/internal/proxy/task_insert_streaming.go @@ -10,6 +10,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/distributed/streaming" + "github.com/milvus-io/milvus/internal/util/hookutil" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/mq/msgstream" "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" @@ -50,12 +51,17 @@ func (it *insertTask) Execute(ctx context.Context) error { zap.Bool("is_parition_key", it.partitionKeys != nil), zap.Duration("get cache duration", getCacheDur)) + var ez *message.CipherConfig + if hookutil.IsClusterEncyptionEnabled() { + ez = hookutil.GetEzByCollProperties(it.schema.GetProperties(), it.collectionID).AsMessageConfig() + } + // start to repack insert data var msgs []message.MutableMessage if it.partitionKeys == nil { - msgs, err = repackInsertDataForStreamingService(it.TraceCtx(), channelNames, it.insertMsg, it.result) + msgs, err = repackInsertDataForStreamingService(it.TraceCtx(), channelNames, it.insertMsg, it.result, ez) } else { - msgs, err = repackInsertDataWithPartitionKeyForStreamingService(it.TraceCtx(), channelNames, it.insertMsg, it.result, it.partitionKeys) + msgs, err = repackInsertDataWithPartitionKeyForStreamingService(it.TraceCtx(), channelNames, it.insertMsg, it.result, it.partitionKeys, ez) } if err != nil { log.Warn("assign segmentID and repack insert data failed", zap.Error(err)) @@ -77,6 +83,7 @@ func repackInsertDataForStreamingService( channelNames []string, insertMsg *msgstream.InsertMsg, result *milvuspb.MutationResult, + ez *message.CipherConfig, ) ([]message.MutableMessage, error) { messages := make([]message.MutableMessage, 0) @@ -107,6 +114,7 @@ func repackInsertDataForStreamingService( }, }). WithBody(insertRequest). + WithCipher(ez). BuildMutable() if err != nil { return nil, err @@ -123,6 +131,7 @@ func repackInsertDataWithPartitionKeyForStreamingService( insertMsg *msgstream.InsertMsg, result *milvuspb.MutationResult, partitionKeys *schemapb.FieldData, + ez *message.CipherConfig, ) ([]message.MutableMessage, error) { messages := make([]message.MutableMessage, 0) @@ -186,6 +195,7 @@ func repackInsertDataWithPartitionKeyForStreamingService( }, }). WithBody(insertRequest). + WithCipher(ez). BuildMutable() if err != nil { return nil, err diff --git a/internal/proxy/task_upsert_streaming.go b/internal/proxy/task_upsert_streaming.go index 5ee6766cd7..0e146f8626 100644 --- a/internal/proxy/task_upsert_streaming.go +++ b/internal/proxy/task_upsert_streaming.go @@ -8,6 +8,7 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus/internal/distributed/streaming" + "github.com/milvus-io/milvus/internal/util/hookutil" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" "github.com/milvus-io/milvus/pkg/v2/util/merr" @@ -20,12 +21,17 @@ func (ut *upsertTask) Execute(ctx context.Context) error { defer sp.End() log := log.Ctx(ctx).With(zap.String("collectionName", ut.req.CollectionName)) - insertMsgs, err := ut.packInsertMessage(ctx) + var ez *message.CipherConfig + if hookutil.IsClusterEncyptionEnabled() { + ez = hookutil.GetEzByCollProperties(ut.schema.GetProperties(), ut.collectionID).AsMessageConfig() + } + + insertMsgs, err := ut.packInsertMessage(ctx, ez) if err != nil { log.Warn("pack insert message failed", zap.Error(err)) return err } - deleteMsgs, err := ut.packDeleteMessage(ctx) + deleteMsgs, err := ut.packDeleteMessage(ctx, ez) if err != nil { log.Warn("pack delete message failed", zap.Error(err)) return err @@ -42,7 +48,7 @@ func (ut *upsertTask) Execute(ctx context.Context) error { return nil } -func (ut *upsertTask) packInsertMessage(ctx context.Context) ([]message.MutableMessage, error) { +func (ut *upsertTask) packInsertMessage(ctx context.Context, ez *message.CipherConfig) ([]message.MutableMessage, error) { tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy insertExecute upsert %d", ut.ID())) defer tr.Elapse("insert execute done when insertExecute") @@ -77,9 +83,9 @@ func (ut *upsertTask) packInsertMessage(ctx context.Context) ([]message.MutableM // start to repack insert data var msgs []message.MutableMessage if ut.partitionKeys == nil { - msgs, err = repackInsertDataForStreamingService(ut.TraceCtx(), channelNames, ut.upsertMsg.InsertMsg, ut.result) + msgs, err = repackInsertDataForStreamingService(ut.TraceCtx(), channelNames, ut.upsertMsg.InsertMsg, ut.result, ez) } else { - msgs, err = repackInsertDataWithPartitionKeyForStreamingService(ut.TraceCtx(), channelNames, ut.upsertMsg.InsertMsg, ut.result, ut.partitionKeys) + msgs, err = repackInsertDataWithPartitionKeyForStreamingService(ut.TraceCtx(), channelNames, ut.upsertMsg.InsertMsg, ut.result, ut.partitionKeys, ez) } if err != nil { log.Warn("assign segmentID and repack insert data failed", zap.Error(err)) @@ -89,27 +95,27 @@ func (ut *upsertTask) packInsertMessage(ctx context.Context) ([]message.MutableM return msgs, nil } -func (it *upsertTask) packDeleteMessage(ctx context.Context) ([]message.MutableMessage, error) { - tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy deleteExecute upsert %d", it.ID())) - collID := it.upsertMsg.DeleteMsg.CollectionID - it.upsertMsg.DeleteMsg.PrimaryKeys = it.oldIDs +func (ut *upsertTask) packDeleteMessage(ctx context.Context, ez *message.CipherConfig) ([]message.MutableMessage, error) { + tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy deleteExecute upsert %d", ut.ID())) + collID := ut.upsertMsg.DeleteMsg.CollectionID + ut.upsertMsg.DeleteMsg.PrimaryKeys = ut.oldIDs log := log.Ctx(ctx).With( zap.Int64("collectionID", collID)) // hash primary keys to channels - vChannels, err := it.chMgr.getVChannels(collID) + vChannels, err := ut.chMgr.getVChannels(collID) if err != nil { log.Warn("get vChannels failed when deleteExecute", zap.Error(err)) - it.result.Status = merr.Status(err) + ut.result.Status = merr.Status(err) return nil, err } result, numRows, err := repackDeleteMsgByHash( ctx, - it.upsertMsg.DeleteMsg.PrimaryKeys, - vChannels, it.idAllocator, - it.BeginTs(), - it.upsertMsg.DeleteMsg.CollectionID, it.upsertMsg.DeleteMsg.CollectionName, - it.upsertMsg.DeleteMsg.PartitionID, it.upsertMsg.DeleteMsg.PartitionName, - it.req.GetDbName(), + ut.upsertMsg.DeleteMsg.PrimaryKeys, + vChannels, ut.idAllocator, + ut.BeginTs(), + ut.upsertMsg.DeleteMsg.CollectionID, ut.upsertMsg.DeleteMsg.CollectionName, + ut.upsertMsg.DeleteMsg.PartitionID, ut.upsertMsg.DeleteMsg.PartitionName, + ut.req.GetDbName(), ) if err != nil { return nil, err @@ -121,7 +127,7 @@ func (it *upsertTask) packDeleteMessage(ctx context.Context) ([]message.MutableM for _, deleteMsg := range deleteMsgs { msg, err := message.NewDeleteMessageBuilderV1(). WithHeader(&message.DeleteMessageHeader{ - CollectionId: it.upsertMsg.DeleteMsg.CollectionID, + CollectionId: ut.upsertMsg.DeleteMsg.CollectionID, Rows: uint64(deleteMsg.NumRows), }). WithBody(deleteMsg.DeleteRequest). @@ -137,7 +143,7 @@ func (it *upsertTask) packDeleteMessage(ctx context.Context) ([]message.MutableM log.Debug("Proxy Upsert deleteExecute done", zap.Int64("collection_id", collID), zap.Strings("virtual_channels", vChannels), - zap.Int64("taskID", it.ID()), + zap.Int64("taskID", ut.ID()), zap.Int64("numRows", numRows), zap.Duration("prepare duration", tr.ElapseSpan())) diff --git a/internal/util/hookutil/cipher.go b/internal/util/hookutil/cipher.go index 02448420c9..bc2d2d282f 100644 --- a/internal/util/hookutil/cipher.go +++ b/internal/util/hookutil/cipher.go @@ -30,6 +30,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/hook" "github.com/milvus-io/milvus/pkg/v2/log" + "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" ) @@ -67,6 +68,13 @@ type EZ struct { CollectionID int64 } +func (ez *EZ) AsMessageConfig() *message.CipherConfig { + if ez == nil { + return nil + } + return &message.CipherConfig{EzID: ez.EzID, CollectionID: ez.CollectionID} +} + type CipherContext struct { EZ key []byte