diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 31417b6c73..f5b86302a0 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -2257,6 +2257,21 @@ func SendReplicateMessagePack(ctx context.Context, replicateMsgStream msgstream. BaseMsg: getBaseMsg(ctx, ts), OperatePrivilegeV2Request: r, } + case *milvuspb.CreatePrivilegeGroupRequest: + tsMsg = &msgstream.CreatePrivilegeGroupMsg{ + BaseMsg: getBaseMsg(ctx, ts), + CreatePrivilegeGroupRequest: r, + } + case *milvuspb.DropPrivilegeGroupRequest: + tsMsg = &msgstream.DropPrivilegeGroupMsg{ + BaseMsg: getBaseMsg(ctx, ts), + DropPrivilegeGroupRequest: r, + } + case *milvuspb.OperatePrivilegeGroupRequest: + tsMsg = &msgstream.OperatePrivilegeGroupMsg{ + BaseMsg: getBaseMsg(ctx, ts), + OperatePrivilegeGroupRequest: r, + } case *milvuspb.CreateAliasRequest: tsMsg = &msgstream.CreateAliasMsg{ BaseMsg: getBaseMsg(ctx, ts), diff --git a/pkg/mq/msgstream/msg_for_user_role.go b/pkg/mq/msgstream/msg_for_user_role.go index 7de44db1f1..c818bbb5bb 100644 --- a/pkg/mq/msgstream/msg_for_user_role.go +++ b/pkg/mq/msgstream/msg_for_user_role.go @@ -447,3 +447,162 @@ func (c *OperatePrivilegeV2Msg) Unmarshal(input MarshalType) (TsMsg, error) { func (c *OperatePrivilegeV2Msg) Size() int { return proto.Size(c.OperatePrivilegeV2Request) } + +type CreatePrivilegeGroupMsg struct { + BaseMsg + *milvuspb.CreatePrivilegeGroupRequest +} + +var _ TsMsg = &CreatePrivilegeGroupMsg{} + +func (c *CreatePrivilegeGroupMsg) ID() UniqueID { + return c.Base.MsgID +} + +func (c *CreatePrivilegeGroupMsg) SetID(id UniqueID) { + c.Base.MsgID = id +} + +func (c *CreatePrivilegeGroupMsg) Type() MsgType { + return c.Base.MsgType +} + +func (c *CreatePrivilegeGroupMsg) SourceID() int64 { + return c.Base.SourceID +} + +func (c *CreatePrivilegeGroupMsg) Marshal(input TsMsg) (MarshalType, error) { + createPrivilegeGroupMsg := input.(*CreatePrivilegeGroupMsg) + createPrivilegeGroupRequest := createPrivilegeGroupMsg.CreatePrivilegeGroupRequest + mb, err := proto.Marshal(createPrivilegeGroupRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (c *CreatePrivilegeGroupMsg) Unmarshal(input MarshalType) (TsMsg, error) { + createPrivilegeGroupRequest := &milvuspb.CreatePrivilegeGroupRequest{} + in, err := convertToByteArray(input) + if err != nil { + return nil, err + } + err = proto.Unmarshal(in, createPrivilegeGroupRequest) + if err != nil { + return nil, err + } + createPrivilegeGroupMsg := &CreatePrivilegeGroupMsg{CreatePrivilegeGroupRequest: createPrivilegeGroupRequest} + createPrivilegeGroupMsg.BeginTimestamp = createPrivilegeGroupMsg.GetBase().GetTimestamp() + createPrivilegeGroupMsg.EndTimestamp = createPrivilegeGroupMsg.GetBase().GetTimestamp() + return createPrivilegeGroupMsg, nil +} + +func (c *CreatePrivilegeGroupMsg) Size() int { + return proto.Size(c.CreatePrivilegeGroupRequest) +} + +type DropPrivilegeGroupMsg struct { + BaseMsg + *milvuspb.DropPrivilegeGroupRequest +} + +var _ TsMsg = &DropPrivilegeGroupMsg{} + +func (c *DropPrivilegeGroupMsg) ID() UniqueID { + return c.Base.MsgID +} + +func (c *DropPrivilegeGroupMsg) SetID(id UniqueID) { + c.Base.MsgID = id +} + +func (c *DropPrivilegeGroupMsg) Type() MsgType { + return c.Base.MsgType +} + +func (c *DropPrivilegeGroupMsg) SourceID() int64 { + return c.Base.SourceID +} + +func (c *DropPrivilegeGroupMsg) Marshal(input TsMsg) (MarshalType, error) { + dropPrivilegeGroupMsg := input.(*DropPrivilegeGroupMsg) + dropPrivilegeGroupRequest := dropPrivilegeGroupMsg.DropPrivilegeGroupRequest + mb, err := proto.Marshal(dropPrivilegeGroupRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (c *DropPrivilegeGroupMsg) Unmarshal(input MarshalType) (TsMsg, error) { + dropPrivilegeGroupRequest := &milvuspb.DropPrivilegeGroupRequest{} + in, err := convertToByteArray(input) + if err != nil { + return nil, err + } + err = proto.Unmarshal(in, dropPrivilegeGroupRequest) + if err != nil { + return nil, err + } + dropPrivilegeGroupMsg := &DropPrivilegeGroupMsg{DropPrivilegeGroupRequest: dropPrivilegeGroupRequest} + dropPrivilegeGroupMsg.BeginTimestamp = dropPrivilegeGroupMsg.GetBase().GetTimestamp() + dropPrivilegeGroupMsg.EndTimestamp = dropPrivilegeGroupMsg.GetBase().GetTimestamp() + return dropPrivilegeGroupMsg, nil +} + +func (c *DropPrivilegeGroupMsg) Size() int { + return proto.Size(c.DropPrivilegeGroupRequest) +} + +type OperatePrivilegeGroupMsg struct { + BaseMsg + *milvuspb.OperatePrivilegeGroupRequest +} + +var _ TsMsg = &OperatePrivilegeGroupMsg{} + +func (c *OperatePrivilegeGroupMsg) ID() UniqueID { + return c.Base.MsgID +} + +func (c *OperatePrivilegeGroupMsg) SetID(id UniqueID) { + c.Base.MsgID = id +} + +func (c *OperatePrivilegeGroupMsg) Type() MsgType { + return c.Base.MsgType +} + +func (c *OperatePrivilegeGroupMsg) SourceID() int64 { + return c.Base.SourceID +} + +func (c *OperatePrivilegeGroupMsg) Marshal(input TsMsg) (MarshalType, error) { + operatePrivilegeGroupMsg := input.(*OperatePrivilegeGroupMsg) + operatePrivilegeGroupRequest := operatePrivilegeGroupMsg.OperatePrivilegeGroupRequest + mb, err := proto.Marshal(operatePrivilegeGroupRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (c *OperatePrivilegeGroupMsg) Unmarshal(input MarshalType) (TsMsg, error) { + operatePrivilegeGroupRequest := &milvuspb.OperatePrivilegeGroupRequest{} + in, err := convertToByteArray(input) + if err != nil { + return nil, err + } + err = proto.Unmarshal(in, operatePrivilegeGroupRequest) + if err != nil { + return nil, err + } + operatePrivilegeGroupMsg := &OperatePrivilegeGroupMsg{OperatePrivilegeGroupRequest: operatePrivilegeGroupRequest} + operatePrivilegeGroupMsg.BeginTimestamp = operatePrivilegeGroupMsg.GetBase().GetTimestamp() + operatePrivilegeGroupMsg.EndTimestamp = operatePrivilegeGroupMsg.GetBase().GetTimestamp() + return operatePrivilegeGroupMsg, nil +} + +func (c *OperatePrivilegeGroupMsg) Size() int { + return proto.Size(c.OperatePrivilegeGroupRequest) +} diff --git a/pkg/mq/msgstream/msg_for_user_role_test.go b/pkg/mq/msgstream/msg_for_user_role_test.go index e28fec6138..51a284295b 100644 --- a/pkg/mq/msgstream/msg_for_user_role_test.go +++ b/pkg/mq/msgstream/msg_for_user_role_test.go @@ -354,3 +354,119 @@ func TestOperatePrivilegeV2(t *testing.T) { assert.EqualValues(t, "unit_user", newMsg.(*OperatePrivilegeV2Msg).GetGrantor().GetUser().GetName()) assert.EqualValues(t, "unit_privilege", newMsg.(*OperatePrivilegeV2Msg).GetGrantor().GetPrivilege().GetName()) } + +func TestCreatePrivilegeGroup(t *testing.T) { + var msg TsMsg = &CreatePrivilegeGroupMsg{ + CreatePrivilegeGroupRequest: &milvuspb.CreatePrivilegeGroupRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_CreatePrivilegeGroup, + MsgID: 100, + Timestamp: 1000, + SourceID: 10000, + TargetID: 100000, + ReplicateInfo: nil, + }, + GroupName: "unit_group", + }, + } + assert.EqualValues(t, 100, msg.ID()) + msg.SetID(200) + assert.EqualValues(t, 200, msg.ID()) + assert.Equal(t, commonpb.MsgType_CreatePrivilegeGroup, msg.Type()) + assert.EqualValues(t, 10000, msg.SourceID()) + + msgBytes, err := msg.Marshal(msg) + assert.NoError(t, err) + + var newMsg TsMsg = &CreatePrivilegeGroupMsg{} + _, err = newMsg.Unmarshal("1") + assert.Error(t, err) + + newMsg, err = newMsg.Unmarshal(msgBytes) + assert.NoError(t, err) + assert.EqualValues(t, 200, newMsg.ID()) + assert.EqualValues(t, 1000, newMsg.BeginTs()) + assert.EqualValues(t, 1000, newMsg.EndTs()) + assert.EqualValues(t, "unit_group", newMsg.(*CreatePrivilegeGroupMsg).GetGroupName()) + assert.EqualValues(t, commonpb.MsgType_CreatePrivilegeGroup, newMsg.Type()) + + assert.True(t, msg.Size() > 0) +} + +func TestDropPrivilegeGroup(t *testing.T) { + var msg TsMsg = &DropPrivilegeGroupMsg{ + DropPrivilegeGroupRequest: &milvuspb.DropPrivilegeGroupRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_DropPrivilegeGroup, + MsgID: 100, + Timestamp: 1000, + SourceID: 10000, + TargetID: 100000, + ReplicateInfo: nil, + }, + GroupName: "unit_group", + }, + } + assert.EqualValues(t, 100, msg.ID()) + msg.SetID(200) + assert.EqualValues(t, 200, msg.ID()) + assert.Equal(t, commonpb.MsgType_DropPrivilegeGroup, msg.Type()) + assert.EqualValues(t, 10000, msg.SourceID()) + + msgBytes, err := msg.Marshal(msg) + assert.NoError(t, err) + + var newMsg TsMsg = &DropPrivilegeGroupMsg{} + _, err = newMsg.Unmarshal("1") + assert.Error(t, err) + + newMsg, err = newMsg.Unmarshal(msgBytes) + assert.NoError(t, err) + assert.EqualValues(t, 200, newMsg.ID()) + assert.EqualValues(t, 1000, newMsg.BeginTs()) + assert.EqualValues(t, 1000, newMsg.EndTs()) + assert.EqualValues(t, "unit_group", newMsg.(*DropPrivilegeGroupMsg).GetGroupName()) + + assert.True(t, msg.Size() > 0) +} + +func TestOperatePrivilegeGroup(t *testing.T) { + var msg TsMsg = &OperatePrivilegeGroupMsg{ + OperatePrivilegeGroupRequest: &milvuspb.OperatePrivilegeGroupRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_OperatePrivilegeGroup, + MsgID: 100, + Timestamp: 1000, + SourceID: 10000, + TargetID: 100000, + ReplicateInfo: nil, + }, + GroupName: "unit_group", + Type: milvuspb.OperatePrivilegeGroupType_AddPrivilegesToGroup, + Privileges: []*milvuspb.PrivilegeEntity{ + {Name: "unit_privilege"}, + }, + }, + } + assert.EqualValues(t, 100, msg.ID()) + msg.SetID(200) + assert.EqualValues(t, 200, msg.ID()) + assert.Equal(t, commonpb.MsgType_OperatePrivilegeGroup, msg.Type()) + assert.EqualValues(t, 10000, msg.SourceID()) + + msgBytes, err := msg.Marshal(msg) + assert.NoError(t, err) + + var newMsg TsMsg = &OperatePrivilegeGroupMsg{} + _, err = newMsg.Unmarshal("1") + assert.Error(t, err) + + newMsg, err = newMsg.Unmarshal(msgBytes) + assert.NoError(t, err) + assert.EqualValues(t, 200, newMsg.ID()) + assert.EqualValues(t, 1000, newMsg.BeginTs()) + assert.EqualValues(t, 1000, newMsg.EndTs()) + assert.EqualValues(t, "unit_group", newMsg.(*OperatePrivilegeGroupMsg).GetGroupName()) + assert.EqualValues(t, milvuspb.OperatePrivilegeGroupType_AddPrivilegesToGroup, newMsg.(*OperatePrivilegeGroupMsg).GetType()) + assert.EqualValues(t, "unit_privilege", newMsg.(*OperatePrivilegeGroupMsg).GetPrivileges()[0].GetName()) +} diff --git a/pkg/mq/msgstream/unmarshal.go b/pkg/mq/msgstream/unmarshal.go index 4ac0888c6d..a5e74a5f3e 100644 --- a/pkg/mq/msgstream/unmarshal.go +++ b/pkg/mq/msgstream/unmarshal.go @@ -89,6 +89,9 @@ func (pudf *ProtoUDFactory) NewUnmarshalDispatcher() *ProtoUnmarshalDispatcher { operateUserRoleMsg := OperateUserRoleMsg{} operatePrivilegeMsg := OperatePrivilegeMsg{} operatePrivilegeV2Msg := OperatePrivilegeV2Msg{} + createPrivilegeGroupMsg := CreatePrivilegeGroupMsg{} + dropPrivilegeGroupMsg := DropPrivilegeGroupMsg{} + operatePrivilegeGroupMsg := OperatePrivilegeGroupMsg{} replicateMsg := ReplicateMsg{} importMsg := ImportMsg{} @@ -128,6 +131,9 @@ func (pudf *ProtoUDFactory) NewUnmarshalDispatcher() *ProtoUnmarshalDispatcher { p.TempMap[commonpb.MsgType_OperateUserRole] = operateUserRoleMsg.Unmarshal p.TempMap[commonpb.MsgType_OperatePrivilege] = operatePrivilegeMsg.Unmarshal p.TempMap[commonpb.MsgType_OperatePrivilegeV2] = operatePrivilegeV2Msg.Unmarshal + p.TempMap[commonpb.MsgType_CreatePrivilegeGroup] = createPrivilegeGroupMsg.Unmarshal + p.TempMap[commonpb.MsgType_DropPrivilegeGroup] = dropPrivilegeGroupMsg.Unmarshal + p.TempMap[commonpb.MsgType_OperatePrivilegeGroup] = operatePrivilegeGroupMsg.Unmarshal p.TempMap[commonpb.MsgType_Replicate] = replicateMsg.Unmarshal p.TempMap[commonpb.MsgType_Import] = importMsg.Unmarshal p.TempMap[commonpb.MsgType_CreateAlias] = createAliasMsg.Unmarshal