diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index 65383987c9..1fb0785b40 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -420,6 +420,7 @@ func (node *Proxy) AlterDatabase(ctx context.Context, request *milvuspb.AlterDat Condition: NewTaskCondition(ctx), AlterDatabaseRequest: request, rootCoord: node.rootCoord, + replicateMsgStream: node.replicateMsgStream, } log := log.Ctx(ctx).With( @@ -4853,6 +4854,10 @@ func (node *Proxy) CreateCredential(ctx context.Context, req *milvuspb.CreateCre err = errors.Wrap(err, "encrypt password failed") return merr.Status(err), nil } + if req.Base == nil { + req.Base = &commonpb.MsgBase{} + } + req.Base.MsgType = commonpb.MsgType_CreateCredential credInfo := &internalpb.CredentialInfo{ Username: req.Username, @@ -4865,6 +4870,9 @@ func (node *Proxy) CreateCredential(ctx context.Context, req *milvuspb.CreateCre zap.Error(err)) return merr.Status(err), nil } + if merr.Ok(result) { + SendReplicateMessagePack(ctx, node.replicateMsgStream, req) + } return result, err } @@ -4922,6 +4930,10 @@ func (node *Proxy) UpdateCredential(ctx context.Context, req *milvuspb.UpdateCre err = errors.Wrap(err, "encrypt password failed") return merr.Status(err), nil } + if req.Base == nil { + req.Base = &commonpb.MsgBase{} + } + req.Base.MsgType = commonpb.MsgType_UpdateCredential updateCredReq := &internalpb.CredentialInfo{ Username: req.Username, Sha256Password: crypto.SHA256(rawNewPassword, req.Username), @@ -4933,6 +4945,9 @@ func (node *Proxy) UpdateCredential(ctx context.Context, req *milvuspb.UpdateCre zap.Error(err)) return merr.Status(err), nil } + if merr.Ok(result) { + SendReplicateMessagePack(ctx, node.replicateMsgStream, req) + } return result, err } @@ -4953,12 +4968,19 @@ func (node *Proxy) DeleteCredential(ctx context.Context, req *milvuspb.DeleteCre err := merr.WrapErrPrivilegeNotPermitted("root user cannot be deleted") return merr.Status(err), nil } + if req.Base == nil { + req.Base = &commonpb.MsgBase{} + } + req.Base.MsgType = commonpb.MsgType_DeleteCredential result, err := node.rootCoord.DeleteCredential(ctx, req) if err != nil { // for error like conntext timeout etc. log.Error("delete credential fail", zap.Error(err)) return merr.Status(err), nil } + if merr.Ok(result) { + SendReplicateMessagePack(ctx, node.replicateMsgStream, req) + } return result, err } @@ -4973,6 +4995,10 @@ func (node *Proxy) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUser if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.ListCredUsersResponse{Status: merr.Status(err)}, nil } + if req.Base == nil { + req.Base = &commonpb.MsgBase{} + } + req.Base.MsgType = commonpb.MsgType_ListCredUsernames rootCoordReq := &milvuspb.ListCredUsersRequest{ Base: commonpbutil.NewMsgBase( commonpbutil.WithMsgType(commonpb.MsgType_ListCredUsernames), @@ -5008,12 +5034,19 @@ func (node *Proxy) CreateRole(ctx context.Context, req *milvuspb.CreateRoleReque if err := ValidateRoleName(roleName); err != nil { return merr.Status(err), nil } + if req.Base == nil { + req.Base = &commonpb.MsgBase{} + } + req.Base.MsgType = commonpb.MsgType_CreateRole result, err := node.rootCoord.CreateRole(ctx, req) if err != nil { log.Warn("fail to create role", zap.Error(err)) return merr.Status(err), nil } + if merr.Ok(result) { + SendReplicateMessagePack(ctx, node.replicateMsgStream, req) + } return result, nil } @@ -5031,6 +5064,10 @@ func (node *Proxy) DropRole(ctx context.Context, req *milvuspb.DropRoleRequest) if err := ValidateRoleName(req.RoleName); err != nil { return merr.Status(err), nil } + if req.Base == nil { + req.Base = &commonpb.MsgBase{} + } + req.Base.MsgType = commonpb.MsgType_DropRole if IsDefaultRole(req.RoleName) { err := merr.WrapErrPrivilegeNotPermitted("the role[%s] is a default role, which can't be dropped", req.GetRoleName()) return merr.Status(err), nil @@ -5042,6 +5079,9 @@ func (node *Proxy) DropRole(ctx context.Context, req *milvuspb.DropRoleRequest) zap.Error(err)) return merr.Status(err), nil } + if merr.Ok(result) { + SendReplicateMessagePack(ctx, node.replicateMsgStream, req) + } return result, nil } @@ -5061,12 +5101,19 @@ func (node *Proxy) OperateUserRole(ctx context.Context, req *milvuspb.OperateUse if err := ValidateRoleName(req.RoleName); err != nil { return merr.Status(err), nil } + if req.Base == nil { + req.Base = &commonpb.MsgBase{} + } + req.Base.MsgType = commonpb.MsgType_OperateUserRole result, err := node.rootCoord.OperateUserRole(ctx, req) if err != nil { log.Warn("fail to operate user role", zap.Error(err)) return merr.Status(err), nil } + if merr.Ok(result) { + SendReplicateMessagePack(ctx, node.replicateMsgStream, req) + } return result, nil } @@ -5088,6 +5135,10 @@ func (node *Proxy) SelectRole(ctx context.Context, req *milvuspb.SelectRoleReque }, nil } } + if req.Base == nil { + req.Base = &commonpb.MsgBase{} + } + req.Base.MsgType = commonpb.MsgType_SelectRole result, err := node.rootCoord.SelectRole(ctx, req) if err != nil { @@ -5118,6 +5169,10 @@ func (node *Proxy) SelectUser(ctx context.Context, req *milvuspb.SelectUserReque }, nil } } + if req.Base == nil { + req.Base = &commonpb.MsgBase{} + } + req.Base.MsgType = commonpb.MsgType_SelectUser result, err := node.rootCoord.SelectUser(ctx, req) if err != nil { @@ -5175,6 +5230,10 @@ func (node *Proxy) OperatePrivilege(ctx context.Context, req *milvuspb.OperatePr if err := node.validPrivilegeParams(req); err != nil { return merr.Status(err), nil } + if req.Base == nil { + req.Base = &commonpb.MsgBase{} + } + req.Base.MsgType = commonpb.MsgType_OperatePrivilege curUser, err := GetCurUserFromContext(ctx) if err != nil { log.Warn("fail to get current user", zap.Error(err)) @@ -5202,6 +5261,9 @@ func (node *Proxy) OperatePrivilege(ctx context.Context, req *milvuspb.OperatePr } } } + if merr.Ok(result) { + SendReplicateMessagePack(ctx, node.replicateMsgStream, req) + } return result, nil } @@ -5248,6 +5310,10 @@ func (node *Proxy) SelectGrant(ctx context.Context, req *milvuspb.SelectGrantReq Status: merr.Status(err), }, nil } + if req.Base == nil { + req.Base = &commonpb.MsgBase{} + } + req.Base.MsgType = commonpb.MsgType_SelectGrant result, err := node.rootCoord.SelectGrant(ctx, req) if err != nil { diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 5ace8e5108..25811b0f75 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -3895,7 +3895,7 @@ func testProxyRole(ctx context.Context, t *testing.T, proxy *Proxy) { resp, _ := proxy.CreateRole(ctx, &milvuspb.CreateRoleRequest{Entity: entity}) assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode) - entity.Name = "unit_test" + entity.Name = "unit_test1000" resp, _ = proxy.CreateRole(ctx, &milvuspb.CreateRoleRequest{Entity: entity}) assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) diff --git a/internal/proxy/task_database.go b/internal/proxy/task_database.go index 1b26169f09..5ae997aab2 100644 --- a/internal/proxy/task_database.go +++ b/internal/proxy/task_database.go @@ -228,6 +228,8 @@ type alterDatabaseTask struct { ctx context.Context rootCoord types.RootCoordClient result *commonpb.Status + + replicateMsgStream msgstream.MsgStream } func (t *alterDatabaseTask) TraceCtx() context.Context { @@ -291,6 +293,7 @@ func (t *alterDatabaseTask) Execute(ctx context.Context) error { return err } + SendReplicateMessagePack(ctx, t.replicateMsgStream, t.AlterDatabaseRequest) t.result = ret return nil } diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 43b3d33e5c..2a2ba490c2 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -1578,6 +1578,11 @@ func SendReplicateMessagePack(ctx context.Context, replicateMsgStream msgstream. BaseMsg: getBaseMsg(ctx, ts), DropDatabaseRequest: r, } + case *milvuspb.AlterDatabaseRequest: + tsMsg = &msgstream.AlterDatabaseMsg{ + BaseMsg: getBaseMsg(ctx, ts), + AlterDatabaseRequest: r, + } case *milvuspb.FlushRequest: tsMsg = &msgstream.FlushMsg{ BaseMsg: getBaseMsg(ctx, ts), @@ -1618,6 +1623,41 @@ func SendReplicateMessagePack(ctx context.Context, replicateMsgStream msgstream. BaseMsg: getBaseMsg(ctx, ts), AlterIndexRequest: r, } + case *milvuspb.CreateCredentialRequest: + tsMsg = &msgstream.CreateUserMsg{ + BaseMsg: getBaseMsg(ctx, ts), + CreateCredentialRequest: r, + } + case *milvuspb.UpdateCredentialRequest: + tsMsg = &msgstream.UpdateUserMsg{ + BaseMsg: getBaseMsg(ctx, ts), + UpdateCredentialRequest: r, + } + case *milvuspb.DeleteCredentialRequest: + tsMsg = &msgstream.DeleteUserMsg{ + BaseMsg: getBaseMsg(ctx, ts), + DeleteCredentialRequest: r, + } + case *milvuspb.CreateRoleRequest: + tsMsg = &msgstream.CreateRoleMsg{ + BaseMsg: getBaseMsg(ctx, ts), + CreateRoleRequest: r, + } + case *milvuspb.DropRoleRequest: + tsMsg = &msgstream.DropRoleMsg{ + BaseMsg: getBaseMsg(ctx, ts), + DropRoleRequest: r, + } + case *milvuspb.OperateUserRoleRequest: + tsMsg = &msgstream.OperateUserRoleMsg{ + BaseMsg: getBaseMsg(ctx, ts), + OperateUserRoleRequest: r, + } + case *milvuspb.OperatePrivilegeRequest: + tsMsg = &msgstream.OperatePrivilegeMsg{ + BaseMsg: getBaseMsg(ctx, ts), + OperatePrivilegeRequest: r, + } default: log.Warn("unknown request", zap.Any("request", request)) return diff --git a/internal/querycoordv2/meta/segment_dist_manager.go b/internal/querycoordv2/meta/segment_dist_manager.go index a0c4013596..51d38fc0fc 100644 --- a/internal/querycoordv2/meta/segment_dist_manager.go +++ b/internal/querycoordv2/meta/segment_dist_manager.go @@ -57,7 +57,7 @@ func (f *ReplicaSegDistFilter) Match(s *Segment) bool { return f.GetCollectionID() == s.GetCollectionID() && f.Contains(s.Node) } -func (f ReplicaSegDistFilter) AddFilter(filter *segDistCriterion) { +func (f *ReplicaSegDistFilter) AddFilter(filter *segDistCriterion) { filter.nodes = f.GetNodes() filter.collectionID = f.GetCollectionID() } diff --git a/pkg/mq/msgstream/msg_for_database.go b/pkg/mq/msgstream/msg_for_database.go index c094ef76cd..1ae1782c30 100644 --- a/pkg/mq/msgstream/msg_for_database.go +++ b/pkg/mq/msgstream/msg_for_database.go @@ -131,3 +131,57 @@ func (d *DropDatabaseMsg) Unmarshal(input MarshalType) (TsMsg, error) { func (d *DropDatabaseMsg) Size() int { return proto.Size(d.DropDatabaseRequest) } + +type AlterDatabaseMsg struct { + BaseMsg + *milvuspb.AlterDatabaseRequest +} + +var _ TsMsg = &AlterDatabaseMsg{} + +func (a *AlterDatabaseMsg) ID() UniqueID { + return a.Base.MsgID +} + +func (a *AlterDatabaseMsg) SetID(id UniqueID) { + a.Base.MsgID = id +} + +func (a *AlterDatabaseMsg) Type() MsgType { + return a.Base.MsgType +} + +func (a *AlterDatabaseMsg) SourceID() int64 { + return a.Base.SourceID +} + +func (a *AlterDatabaseMsg) Marshal(input TsMsg) (MarshalType, error) { + alterDataBaseMsg := input.(*AlterDatabaseMsg) + alterDatabaseRequest := alterDataBaseMsg.AlterDatabaseRequest + mb, err := proto.Marshal(alterDatabaseRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (a *AlterDatabaseMsg) Unmarshal(input MarshalType) (TsMsg, error) { + alterDatabaseRequest := &milvuspb.AlterDatabaseRequest{} + in, err := convertToByteArray(input) + if err != nil { + return nil, err + } + err = proto.Unmarshal(in, alterDatabaseRequest) + if err != nil { + return nil, err + } + alterDatabaseMsg := &AlterDatabaseMsg{AlterDatabaseRequest: alterDatabaseRequest} + alterDatabaseMsg.BeginTimestamp = alterDatabaseMsg.GetBase().GetTimestamp() + alterDatabaseMsg.EndTimestamp = alterDatabaseMsg.GetBase().GetTimestamp() + + return alterDatabaseMsg, nil +} + +func (a *AlterDatabaseMsg) Size() int { + return proto.Size(a.AlterDatabaseRequest) +} diff --git a/pkg/mq/msgstream/msg_for_database_test.go b/pkg/mq/msgstream/msg_for_database_test.go index 941b96ed8a..ba3ef2333e 100644 --- a/pkg/mq/msgstream/msg_for_database_test.go +++ b/pkg/mq/msgstream/msg_for_database_test.go @@ -100,3 +100,46 @@ func TestDropDatabase(t *testing.T) { assert.True(t, msg.Size() > 0) } + +func TestAlterDatabase(t *testing.T) { + var msg TsMsg = &AlterDatabaseMsg{ + AlterDatabaseRequest: &milvuspb.AlterDatabaseRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_AlterDatabase, + MsgID: 100, + Timestamp: 1000, + SourceID: 10000, + TargetID: 100000, + ReplicateInfo: nil, + }, + DbName: "unit_db", + Properties: []*commonpb.KeyValuePair{ + { + Key: "key", + Value: "value", + }, + }, + }, + } + assert.EqualValues(t, 100, msg.ID()) + msg.SetID(200) + assert.EqualValues(t, 200, msg.ID()) + assert.Equal(t, commonpb.MsgType_AlterDatabase, msg.Type()) + assert.EqualValues(t, 10000, msg.SourceID()) + + msgBytes, err := msg.Marshal(msg) + assert.NoError(t, err) + + var newMsg TsMsg = &AlterDatabaseMsg{} + _, err = newMsg.Unmarshal("1") + assert.Error(t, err) + + newMsg, err = newMsg.Unmarshal(msgBytes) + assert.NoError(t, err) + assert.EqualValues(t, 200, newMsg.ID()) + assert.EqualValues(t, 1000, newMsg.BeginTs()) + assert.EqualValues(t, 1000, newMsg.EndTs()) + assert.EqualValues(t, "unit_db", newMsg.(*AlterDatabaseMsg).DbName) + assert.EqualValues(t, "key", newMsg.(*AlterDatabaseMsg).Properties[0].Key) + assert.EqualValues(t, "value", newMsg.(*AlterDatabaseMsg).Properties[0].Value) +} diff --git a/pkg/mq/msgstream/msg_for_user_role.go b/pkg/mq/msgstream/msg_for_user_role.go new file mode 100644 index 0000000000..543001aedd --- /dev/null +++ b/pkg/mq/msgstream/msg_for_user_role.go @@ -0,0 +1,396 @@ +/* + * Licensed to the LF AI & Data foundation under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package msgstream + +import ( + "google.golang.org/protobuf/proto" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" +) + +type CreateUserMsg struct { + BaseMsg + *milvuspb.CreateCredentialRequest +} + +var _ TsMsg = &CreateUserMsg{} + +func (c *CreateUserMsg) ID() UniqueID { + return c.Base.MsgID +} + +func (c *CreateUserMsg) SetID(id UniqueID) { + c.Base.MsgID = id +} + +func (c *CreateUserMsg) Type() MsgType { + return c.Base.MsgType +} + +func (c *CreateUserMsg) SourceID() int64 { + return c.Base.SourceID +} + +func (c *CreateUserMsg) Marshal(input TsMsg) (MarshalType, error) { + createUserMsg := input.(*CreateUserMsg) + createUserRequest := createUserMsg.CreateCredentialRequest + mb, err := proto.Marshal(createUserRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (c *CreateUserMsg) Unmarshal(input MarshalType) (TsMsg, error) { + createUserRequest := &milvuspb.CreateCredentialRequest{} + in, err := convertToByteArray(input) + if err != nil { + return nil, err + } + err = proto.Unmarshal(in, createUserRequest) + if err != nil { + return nil, err + } + createUserMsg := &CreateUserMsg{CreateCredentialRequest: createUserRequest} + createUserMsg.BeginTimestamp = createUserMsg.GetBase().GetTimestamp() + createUserMsg.EndTimestamp = createUserMsg.GetBase().GetTimestamp() + return createUserMsg, nil +} + +func (c *CreateUserMsg) Size() int { + return proto.Size(c.CreateCredentialRequest) +} + +type UpdateUserMsg struct { + BaseMsg + *milvuspb.UpdateCredentialRequest +} + +var _ TsMsg = &UpdateUserMsg{} + +func (c *UpdateUserMsg) ID() UniqueID { + return c.Base.MsgID +} + +func (c *UpdateUserMsg) SetID(id UniqueID) { + c.Base.MsgID = id +} + +func (c *UpdateUserMsg) Type() MsgType { + return c.Base.MsgType +} + +func (c *UpdateUserMsg) SourceID() int64 { + return c.Base.SourceID +} + +func (c *UpdateUserMsg) Marshal(input TsMsg) (MarshalType, error) { + updateUserMsg := input.(*UpdateUserMsg) + updateUserRequest := updateUserMsg.UpdateCredentialRequest + mb, err := proto.Marshal(updateUserRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (c *UpdateUserMsg) Unmarshal(input MarshalType) (TsMsg, error) { + updateUserRequest := &milvuspb.UpdateCredentialRequest{} + in, err := convertToByteArray(input) + if err != nil { + return nil, err + } + err = proto.Unmarshal(in, updateUserRequest) + if err != nil { + return nil, err + } + updateUserMsg := &UpdateUserMsg{UpdateCredentialRequest: updateUserRequest} + updateUserMsg.BeginTimestamp = updateUserMsg.GetBase().GetTimestamp() + updateUserMsg.EndTimestamp = updateUserMsg.GetBase().GetTimestamp() + return updateUserMsg, nil +} + +func (c *UpdateUserMsg) Size() int { + return proto.Size(c.UpdateCredentialRequest) +} + +type DeleteUserMsg struct { + BaseMsg + *milvuspb.DeleteCredentialRequest +} + +var _ TsMsg = &DeleteUserMsg{} + +func (c *DeleteUserMsg) ID() UniqueID { + return c.Base.MsgID +} + +func (c *DeleteUserMsg) SetID(id UniqueID) { + c.Base.MsgID = id +} + +func (c *DeleteUserMsg) Type() MsgType { + return c.Base.MsgType +} + +func (c *DeleteUserMsg) SourceID() int64 { + return c.Base.SourceID +} + +func (c *DeleteUserMsg) Marshal(input TsMsg) (MarshalType, error) { + deleteUserMsg := input.(*DeleteUserMsg) + deleteUserRequest := deleteUserMsg.DeleteCredentialRequest + mb, err := proto.Marshal(deleteUserRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (c *DeleteUserMsg) Unmarshal(input MarshalType) (TsMsg, error) { + deleteUserRequest := &milvuspb.DeleteCredentialRequest{} + in, err := convertToByteArray(input) + if err != nil { + return nil, err + } + err = proto.Unmarshal(in, deleteUserRequest) + if err != nil { + return nil, err + } + deleteUserMsg := &DeleteUserMsg{DeleteCredentialRequest: deleteUserRequest} + deleteUserMsg.BeginTimestamp = deleteUserMsg.GetBase().GetTimestamp() + deleteUserMsg.EndTimestamp = deleteUserMsg.GetBase().GetTimestamp() + return deleteUserMsg, nil +} + +func (c *DeleteUserMsg) Size() int { + return proto.Size(c.DeleteCredentialRequest) +} + +type CreateRoleMsg struct { + BaseMsg + *milvuspb.CreateRoleRequest +} + +var _ TsMsg = &CreateRoleMsg{} + +func (c *CreateRoleMsg) ID() UniqueID { + return c.Base.MsgID +} + +func (c *CreateRoleMsg) SetID(id UniqueID) { + c.Base.MsgID = id +} + +func (c *CreateRoleMsg) Type() MsgType { + return c.Base.MsgType +} + +func (c *CreateRoleMsg) SourceID() int64 { + return c.Base.SourceID +} + +func (c *CreateRoleMsg) Marshal(input TsMsg) (MarshalType, error) { + createRoleMsg := input.(*CreateRoleMsg) + createRoleRequest := createRoleMsg.CreateRoleRequest + mb, err := proto.Marshal(createRoleRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (c *CreateRoleMsg) Unmarshal(input MarshalType) (TsMsg, error) { + createRoleRequest := &milvuspb.CreateRoleRequest{} + in, err := convertToByteArray(input) + if err != nil { + return nil, err + } + err = proto.Unmarshal(in, createRoleRequest) + if err != nil { + return nil, err + } + createRoleMsg := &CreateRoleMsg{CreateRoleRequest: createRoleRequest} + createRoleMsg.BeginTimestamp = createRoleMsg.GetBase().GetTimestamp() + createRoleMsg.EndTimestamp = createRoleMsg.GetBase().GetTimestamp() + return createRoleMsg, nil +} + +func (c *CreateRoleMsg) Size() int { + return proto.Size(c.CreateRoleRequest) +} + +type DropRoleMsg struct { + BaseMsg + *milvuspb.DropRoleRequest +} + +var _ TsMsg = &DropRoleMsg{} + +func (c *DropRoleMsg) ID() UniqueID { + return c.Base.MsgID +} + +func (c *DropRoleMsg) SetID(id UniqueID) { + c.Base.MsgID = id +} + +func (c *DropRoleMsg) Type() MsgType { + return c.Base.MsgType +} + +func (c *DropRoleMsg) SourceID() int64 { + return c.Base.SourceID +} + +func (c *DropRoleMsg) Marshal(input TsMsg) (MarshalType, error) { + dropRoleMsg := input.(*DropRoleMsg) + dropRoleRequest := dropRoleMsg.DropRoleRequest + mb, err := proto.Marshal(dropRoleRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (c *DropRoleMsg) Unmarshal(input MarshalType) (TsMsg, error) { + dropRoleRequest := &milvuspb.DropRoleRequest{} + in, err := convertToByteArray(input) + if err != nil { + return nil, err + } + err = proto.Unmarshal(in, dropRoleRequest) + if err != nil { + return nil, err + } + dropRoleMsg := &DropRoleMsg{DropRoleRequest: dropRoleRequest} + dropRoleMsg.BeginTimestamp = dropRoleMsg.GetBase().GetTimestamp() + dropRoleMsg.EndTimestamp = dropRoleMsg.GetBase().GetTimestamp() + return dropRoleMsg, nil +} + +func (c *DropRoleMsg) Size() int { + return proto.Size(c.DropRoleRequest) +} + +type OperateUserRoleMsg struct { + BaseMsg + *milvuspb.OperateUserRoleRequest +} + +var _ TsMsg = &OperateUserRoleMsg{} + +func (c *OperateUserRoleMsg) ID() UniqueID { + return c.Base.MsgID +} + +func (c *OperateUserRoleMsg) SetID(id UniqueID) { + c.Base.MsgID = id +} + +func (c *OperateUserRoleMsg) Type() MsgType { + return c.Base.MsgType +} + +func (c *OperateUserRoleMsg) SourceID() int64 { + return c.Base.SourceID +} + +func (c *OperateUserRoleMsg) Marshal(input TsMsg) (MarshalType, error) { + operateUserRoleMsg := input.(*OperateUserRoleMsg) + operateUserRoleRequest := operateUserRoleMsg.OperateUserRoleRequest + mb, err := proto.Marshal(operateUserRoleRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (c *OperateUserRoleMsg) Unmarshal(input MarshalType) (TsMsg, error) { + operateUserRoleRequest := &milvuspb.OperateUserRoleRequest{} + in, err := convertToByteArray(input) + if err != nil { + return nil, err + } + err = proto.Unmarshal(in, operateUserRoleRequest) + if err != nil { + return nil, err + } + operateUserRoleMsg := &OperateUserRoleMsg{OperateUserRoleRequest: operateUserRoleRequest} + operateUserRoleMsg.BeginTimestamp = operateUserRoleMsg.GetBase().GetTimestamp() + operateUserRoleMsg.EndTimestamp = operateUserRoleMsg.GetBase().GetTimestamp() + return operateUserRoleMsg, nil +} + +func (c *OperateUserRoleMsg) Size() int { + return proto.Size(c.OperateUserRoleRequest) +} + +type OperatePrivilegeMsg struct { + BaseMsg + *milvuspb.OperatePrivilegeRequest +} + +var _ TsMsg = &OperatePrivilegeMsg{} + +func (c *OperatePrivilegeMsg) ID() UniqueID { + return c.Base.MsgID +} + +func (c *OperatePrivilegeMsg) SetID(id UniqueID) { + c.Base.MsgID = id +} + +func (c *OperatePrivilegeMsg) Type() MsgType { + return c.Base.MsgType +} + +func (c *OperatePrivilegeMsg) SourceID() int64 { + return c.Base.SourceID +} + +func (c *OperatePrivilegeMsg) Marshal(input TsMsg) (MarshalType, error) { + operatePrivilegeMsg := input.(*OperatePrivilegeMsg) + operatePrivilegeRequest := operatePrivilegeMsg.OperatePrivilegeRequest + mb, err := proto.Marshal(operatePrivilegeRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (c *OperatePrivilegeMsg) Unmarshal(input MarshalType) (TsMsg, error) { + operatePrivilegeRequest := &milvuspb.OperatePrivilegeRequest{} + in, err := convertToByteArray(input) + if err != nil { + return nil, err + } + err = proto.Unmarshal(in, operatePrivilegeRequest) + if err != nil { + return nil, err + } + operatePrivilegeMsg := &OperatePrivilegeMsg{OperatePrivilegeRequest: operatePrivilegeRequest} + operatePrivilegeMsg.BeginTimestamp = operatePrivilegeMsg.GetBase().GetTimestamp() + operatePrivilegeMsg.EndTimestamp = operatePrivilegeMsg.GetBase().GetTimestamp() + return operatePrivilegeMsg, nil +} + +func (c *OperatePrivilegeMsg) Size() int { + return proto.Size(c.OperatePrivilegeRequest) +} diff --git a/pkg/mq/msgstream/msg_for_user_role_test.go b/pkg/mq/msgstream/msg_for_user_role_test.go new file mode 100644 index 0000000000..0d928107bb --- /dev/null +++ b/pkg/mq/msgstream/msg_for_user_role_test.go @@ -0,0 +1,314 @@ +/* + * Licensed to the LF AI & Data foundation under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package msgstream + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" +) + +func TestCreateUser(t *testing.T) { + var msg TsMsg = &CreateUserMsg{ + CreateCredentialRequest: &milvuspb.CreateCredentialRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_CreateCredential, + MsgID: 100, + Timestamp: 1000, + SourceID: 10000, + TargetID: 100000, + ReplicateInfo: nil, + }, + Username: "unit_user", + Password: "unit_password", + }, + } + assert.EqualValues(t, 100, msg.ID()) + msg.SetID(200) + assert.EqualValues(t, 200, msg.ID()) + assert.Equal(t, commonpb.MsgType_CreateCredential, msg.Type()) + assert.EqualValues(t, 10000, msg.SourceID()) + + msgBytes, err := msg.Marshal(msg) + assert.NoError(t, err) + + var newMsg TsMsg = &CreateUserMsg{} + _, err = newMsg.Unmarshal("1") + assert.Error(t, err) + + newMsg, err = newMsg.Unmarshal(msgBytes) + assert.NoError(t, err) + assert.EqualValues(t, 200, newMsg.ID()) + assert.EqualValues(t, 1000, newMsg.BeginTs()) + assert.EqualValues(t, 1000, newMsg.EndTs()) + assert.EqualValues(t, "unit_user", newMsg.(*CreateUserMsg).Username) + assert.EqualValues(t, "unit_password", newMsg.(*CreateUserMsg).Password) + + assert.True(t, msg.Size() > 0) +} + +func TestUpdateUser(t *testing.T) { + var msg TsMsg = &UpdateUserMsg{ + UpdateCredentialRequest: &milvuspb.UpdateCredentialRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_UpdateCredential, + MsgID: 100, + Timestamp: 1000, + SourceID: 10000, + TargetID: 100000, + ReplicateInfo: nil, + }, + Username: "unit_user", + OldPassword: "unit_old_password", + NewPassword: "unit_new_password", + }, + } + assert.EqualValues(t, 100, msg.ID()) + msg.SetID(200) + assert.EqualValues(t, 200, msg.ID()) + assert.Equal(t, commonpb.MsgType_UpdateCredential, msg.Type()) + assert.EqualValues(t, 10000, msg.SourceID()) + + msgBytes, err := msg.Marshal(msg) + assert.NoError(t, err) + + var newMsg TsMsg = &UpdateUserMsg{} + _, err = newMsg.Unmarshal("1") + assert.Error(t, err) + + newMsg, err = newMsg.Unmarshal(msgBytes) + assert.NoError(t, err) + assert.EqualValues(t, 200, newMsg.ID()) + assert.EqualValues(t, 1000, newMsg.BeginTs()) + assert.EqualValues(t, 1000, newMsg.EndTs()) + assert.EqualValues(t, "unit_user", newMsg.(*UpdateUserMsg).Username) + assert.EqualValues(t, "unit_old_password", newMsg.(*UpdateUserMsg).OldPassword) + assert.EqualValues(t, "unit_new_password", newMsg.(*UpdateUserMsg).NewPassword) + + assert.True(t, msg.Size() > 0) +} + +func TestDeleteUser(t *testing.T) { + var msg TsMsg = &DeleteUserMsg{ + DeleteCredentialRequest: &milvuspb.DeleteCredentialRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_DeleteCredential, + MsgID: 100, + Timestamp: 1000, + SourceID: 10000, + TargetID: 100000, + ReplicateInfo: nil, + }, + Username: "unit_user", + }, + } + assert.EqualValues(t, 100, msg.ID()) + msg.SetID(200) + assert.EqualValues(t, 200, msg.ID()) + assert.Equal(t, commonpb.MsgType_DeleteCredential, msg.Type()) + assert.EqualValues(t, 10000, msg.SourceID()) + + msgBytes, err := msg.Marshal(msg) + assert.NoError(t, err) + + var newMsg TsMsg = &DeleteUserMsg{} + _, err = newMsg.Unmarshal("1") + assert.Error(t, err) + + newMsg, err = newMsg.Unmarshal(msgBytes) + assert.NoError(t, err) + assert.EqualValues(t, 200, newMsg.ID()) + assert.EqualValues(t, 1000, newMsg.BeginTs()) + assert.EqualValues(t, 1000, newMsg.EndTs()) + assert.EqualValues(t, "unit_user", newMsg.(*DeleteUserMsg).Username) + + assert.True(t, msg.Size() > 0) +} + +func TestCreateRole(t *testing.T) { + var msg TsMsg = &CreateRoleMsg{ + CreateRoleRequest: &milvuspb.CreateRoleRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_CreateRole, + MsgID: 100, + Timestamp: 1000, + SourceID: 10000, + TargetID: 100000, + ReplicateInfo: nil, + }, + Entity: &milvuspb.RoleEntity{ + Name: "unit_role", + }, + }, + } + assert.EqualValues(t, 100, msg.ID()) + msg.SetID(200) + assert.EqualValues(t, 200, msg.ID()) + assert.Equal(t, commonpb.MsgType_CreateRole, msg.Type()) + assert.EqualValues(t, 10000, msg.SourceID()) + + msgBytes, err := msg.Marshal(msg) + assert.NoError(t, err) + + var newMsg TsMsg = &CreateRoleMsg{} + _, err = newMsg.Unmarshal("1") + assert.Error(t, err) + + newMsg, err = newMsg.Unmarshal(msgBytes) + assert.NoError(t, err) + assert.EqualValues(t, 200, newMsg.ID()) + assert.EqualValues(t, 1000, newMsg.BeginTs()) + assert.EqualValues(t, 1000, newMsg.EndTs()) + assert.EqualValues(t, "unit_role", newMsg.(*CreateRoleMsg).GetEntity().GetName()) + + assert.True(t, msg.Size() > 0) +} + +func TestDropRole(t *testing.T) { + var msg TsMsg = &DropRoleMsg{ + DropRoleRequest: &milvuspb.DropRoleRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_DropRole, + MsgID: 100, + Timestamp: 1000, + SourceID: 10000, + TargetID: 100000, + ReplicateInfo: nil, + }, + RoleName: "unit_role", + }, + } + assert.EqualValues(t, 100, msg.ID()) + msg.SetID(200) + assert.EqualValues(t, 200, msg.ID()) + assert.Equal(t, commonpb.MsgType_DropRole, msg.Type()) + assert.EqualValues(t, 10000, msg.SourceID()) + + msgBytes, err := msg.Marshal(msg) + assert.NoError(t, err) + + var newMsg TsMsg = &DropRoleMsg{} + _, err = newMsg.Unmarshal("1") + assert.Error(t, err) + + newMsg, err = newMsg.Unmarshal(msgBytes) + assert.NoError(t, err) + assert.EqualValues(t, 200, newMsg.ID()) + assert.EqualValues(t, 1000, newMsg.BeginTs()) + assert.EqualValues(t, 1000, newMsg.EndTs()) + assert.EqualValues(t, "unit_role", newMsg.(*DropRoleMsg).GetRoleName()) + + assert.True(t, msg.Size() > 0) +} + +func TestOperateUserRole(t *testing.T) { + var msg TsMsg = &OperateUserRoleMsg{ + OperateUserRoleRequest: &milvuspb.OperateUserRoleRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_OperateUserRole, + MsgID: 100, + Timestamp: 1000, + SourceID: 10000, + TargetID: 100000, + ReplicateInfo: nil, + }, + RoleName: "unit_role", + Username: "unit_user", + Type: milvuspb.OperateUserRoleType_AddUserToRole, + }, + } + assert.EqualValues(t, 100, msg.ID()) + msg.SetID(200) + assert.EqualValues(t, 200, msg.ID()) + assert.Equal(t, commonpb.MsgType_OperateUserRole, msg.Type()) + assert.EqualValues(t, 10000, msg.SourceID()) + + msgBytes, err := msg.Marshal(msg) + assert.NoError(t, err) + + var newMsg TsMsg = &OperateUserRoleMsg{} + _, err = newMsg.Unmarshal("1") + assert.Error(t, err) + + newMsg, err = newMsg.Unmarshal(msgBytes) + assert.NoError(t, err) + assert.EqualValues(t, 200, newMsg.ID()) + assert.EqualValues(t, 1000, newMsg.BeginTs()) + assert.EqualValues(t, 1000, newMsg.EndTs()) + assert.EqualValues(t, "unit_role", newMsg.(*OperateUserRoleMsg).GetRoleName()) + assert.EqualValues(t, "unit_user", newMsg.(*OperateUserRoleMsg).GetUsername()) + assert.EqualValues(t, milvuspb.OperateUserRoleType_AddUserToRole, newMsg.(*OperateUserRoleMsg).GetType()) + + assert.True(t, msg.Size() > 0) +} + +func TestOperatePrivilege(t *testing.T) { + var msg TsMsg = &OperatePrivilegeMsg{ + OperatePrivilegeRequest: &milvuspb.OperatePrivilegeRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_OperatePrivilege, + MsgID: 100, + Timestamp: 1000, + SourceID: 10000, + TargetID: 100000, + ReplicateInfo: nil, + }, + Entity: &milvuspb.GrantEntity{ + Role: &milvuspb.RoleEntity{Name: "unit_role"}, + Object: &milvuspb.ObjectEntity{Name: "Collection"}, + ObjectName: "col1", + Grantor: &milvuspb.GrantorEntity{ + User: &milvuspb.UserEntity{Name: "unit_user"}, + Privilege: &milvuspb.PrivilegeEntity{Name: "unit_privilege"}, + }, + DbName: "unit_db", + }, + Type: milvuspb.OperatePrivilegeType_Grant, + }, + } + assert.EqualValues(t, 100, msg.ID()) + msg.SetID(200) + assert.EqualValues(t, 200, msg.ID()) + assert.Equal(t, commonpb.MsgType_OperatePrivilege, msg.Type()) + assert.EqualValues(t, 10000, msg.SourceID()) + + msgBytes, err := msg.Marshal(msg) + assert.NoError(t, err) + + var newMsg TsMsg = &OperatePrivilegeMsg{} + _, err = newMsg.Unmarshal("1") + assert.Error(t, err) + + newMsg, err = newMsg.Unmarshal(msgBytes) + assert.NoError(t, err) + assert.EqualValues(t, 200, newMsg.ID()) + assert.EqualValues(t, 1000, newMsg.BeginTs()) + assert.EqualValues(t, 1000, newMsg.EndTs()) + assert.EqualValues(t, "unit_role", newMsg.(*OperatePrivilegeMsg).GetEntity().GetRole().GetName()) + assert.EqualValues(t, "Collection", newMsg.(*OperatePrivilegeMsg).GetEntity().GetObject().GetName()) + assert.EqualValues(t, "col1", newMsg.(*OperatePrivilegeMsg).GetEntity().GetObjectName()) + assert.EqualValues(t, "unit_user", newMsg.(*OperatePrivilegeMsg).GetEntity().GetGrantor().GetUser().GetName()) + assert.EqualValues(t, "unit_privilege", newMsg.(*OperatePrivilegeMsg).GetEntity().GetGrantor().GetPrivilege().GetName()) + assert.EqualValues(t, milvuspb.OperatePrivilegeType_Grant, newMsg.(*OperatePrivilegeMsg).GetType()) + + assert.True(t, msg.Size() > 0) +}