mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
enhance: remove old arch non-streaming arch code (#43651)
issue: #41609 - remove all dml dead code at proxy - remove dead code at l0_write_buffer - remove msgstream dependency at proxy - remove timetick reporter from proxy - remove replicate stream implementation --------- Signed-off-by: chyezh <chyezh@outlook.com>
This commit is contained in:
parent
6ae727775f
commit
5551d99425
@ -146,9 +146,7 @@ func GetMilvusRoles(args []string, flags *flag.FlagSet) *roles.MilvusRoles {
|
||||
role.EnableProxy = true
|
||||
role.EnableQueryNode = true
|
||||
role.EnableDataNode = true
|
||||
if streamingutil.IsStreamingServiceEnabled() {
|
||||
role.EnableStreamingNode = true
|
||||
}
|
||||
role.EnableStreamingNode = true
|
||||
role.Local = true
|
||||
role.Embedded = serverType == typeutil.EmbeddedRole
|
||||
case typeutil.MixCoordRole:
|
||||
|
||||
@ -12,6 +12,7 @@ packages:
|
||||
Utility:
|
||||
Broadcast:
|
||||
Local:
|
||||
Scanner:
|
||||
github.com/milvus-io/milvus/internal/streamingcoord/server/balancer:
|
||||
interfaces:
|
||||
Balancer:
|
||||
|
||||
@ -1462,6 +1462,9 @@ func (s *Server) GetFlushState(ctx context.Context, req *datapb.GetFlushStateReq
|
||||
for _, sid := range req.GetSegmentIDs() {
|
||||
segment := s.meta.GetHealthySegment(ctx, sid)
|
||||
// segment is nil if it was compacted, or it's an empty segment and is set to dropped
|
||||
// TODO: Here's a dirty implementation, because a growing segment may cannot be seen right away by mixcoord,
|
||||
// it can only be seen by streamingnode right away, so we need to check the flush state at streamingnode but not here.
|
||||
// use timetick for GetFlushState in-future but not segment list.
|
||||
if segment == nil || isFlushState(segment.GetState()) {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -32,13 +32,10 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
dn "github.com/milvus-io/milvus/internal/datanode"
|
||||
mix "github.com/milvus-io/milvus/internal/distributed/mixcoord/client"
|
||||
"github.com/milvus-io/milvus/internal/distributed/utils"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/componentutil"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
_ "github.com/milvus-io/milvus/internal/util/grpcclient"
|
||||
"github.com/milvus-io/milvus/internal/util/streamingutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
@ -65,8 +62,6 @@ type Server struct {
|
||||
factory dependency.Factory
|
||||
|
||||
serverID atomic.Int64
|
||||
|
||||
mixCoordClient func() (types.MixCoordClient, error)
|
||||
}
|
||||
|
||||
// NewServer new DataNode grpc server
|
||||
@ -77,9 +72,6 @@ func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error)
|
||||
cancel: cancel,
|
||||
factory: factory,
|
||||
grpcErrChan: make(chan error),
|
||||
mixCoordClient: func() (types.MixCoordClient, error) {
|
||||
return mix.NewClient(ctx1)
|
||||
},
|
||||
}
|
||||
|
||||
s.serverID.Store(paramtable.GetNodeID())
|
||||
@ -173,10 +165,6 @@ func (s *Server) SetEtcdClient(client *clientv3.Client) {
|
||||
s.datanode.SetEtcdClient(client)
|
||||
}
|
||||
|
||||
func (s *Server) SetMixCoordInterface(ms types.MixCoordClient) error {
|
||||
return s.datanode.SetMixCoordClient(ms)
|
||||
}
|
||||
|
||||
// Run initializes and starts Datanode's grpc service.
|
||||
func (s *Server) Run() error {
|
||||
if err := s.init(); err != nil {
|
||||
@ -255,27 +243,6 @@ func (s *Server) init() error {
|
||||
return err
|
||||
}
|
||||
|
||||
if !streamingutil.IsStreamingServiceEnabled() {
|
||||
// --- MixCoord Client ---
|
||||
if s.mixCoordClient != nil {
|
||||
log.Info("initializing MixCoord client for DataNode")
|
||||
mixCoordClient, err := s.mixCoordClient()
|
||||
if err != nil {
|
||||
log.Error("failed to create new MixCoord client", zap.Error(err))
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err = componentutil.WaitForComponentHealthy(s.ctx, mixCoordClient, "MixCoord", 1000000, time.Millisecond*200); err != nil {
|
||||
log.Error("failed to wait for MixCoord client to be ready", zap.Error(err))
|
||||
panic(err)
|
||||
}
|
||||
log.Info("MixCoord client is ready for DataNode")
|
||||
if err = s.SetMixCoordInterface(mixCoordClient); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.datanode.UpdateStateCode(commonpb.StateCode_Initializing)
|
||||
|
||||
if err := s.datanode.Init(); err != nil {
|
||||
|
||||
@ -27,7 +27,6 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/workerpb"
|
||||
@ -43,27 +42,10 @@ func Test_NewServer(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, server)
|
||||
|
||||
mockMixCoord := mocks.NewMockMixCoordClient(t)
|
||||
mockMixCoord.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{
|
||||
State: &milvuspb.ComponentInfo{
|
||||
StateCode: commonpb.StateCode_Healthy,
|
||||
},
|
||||
Status: merr.Success(),
|
||||
SubcomponentStates: []*milvuspb.ComponentInfo{
|
||||
{
|
||||
StateCode: commonpb.StateCode_Healthy,
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
server.mixCoordClient = func() (types.MixCoordClient, error) {
|
||||
return mockMixCoord, nil
|
||||
}
|
||||
|
||||
t.Run("Run", func(t *testing.T) {
|
||||
datanode := mocks.NewMockDataNode(t)
|
||||
datanode.EXPECT().SetEtcdClient(mock.Anything).Return()
|
||||
datanode.EXPECT().SetAddress(mock.Anything).Return()
|
||||
datanode.EXPECT().SetMixCoordClient(mock.Anything).Return(nil)
|
||||
datanode.EXPECT().UpdateStateCode(mock.Anything).Return()
|
||||
datanode.EXPECT().Register().Return(nil)
|
||||
datanode.EXPECT().Init().Return(nil)
|
||||
@ -191,26 +173,9 @@ func Test_Run(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, server)
|
||||
|
||||
mockRootCoord := mocks.NewMockMixCoordClient(t)
|
||||
mockRootCoord.EXPECT().GetComponentStates(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{
|
||||
State: &milvuspb.ComponentInfo{
|
||||
StateCode: commonpb.StateCode_Healthy,
|
||||
},
|
||||
Status: merr.Success(),
|
||||
SubcomponentStates: []*milvuspb.ComponentInfo{
|
||||
{
|
||||
StateCode: commonpb.StateCode_Healthy,
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
server.mixCoordClient = func() (types.MixCoordClient, error) {
|
||||
return mockRootCoord, nil
|
||||
}
|
||||
|
||||
datanode := mocks.NewMockDataNode(t)
|
||||
datanode.EXPECT().SetEtcdClient(mock.Anything).Return()
|
||||
datanode.EXPECT().SetAddress(mock.Anything).Return()
|
||||
datanode.EXPECT().SetMixCoordClient(mock.Anything).Return(nil)
|
||||
datanode.EXPECT().UpdateStateCode(mock.Anything).Return()
|
||||
datanode.EXPECT().Init().Return(errors.New("mock err"))
|
||||
server.datanode = datanode
|
||||
@ -223,7 +188,6 @@ func Test_Run(t *testing.T) {
|
||||
datanode = mocks.NewMockDataNode(t)
|
||||
datanode.EXPECT().SetEtcdClient(mock.Anything).Return()
|
||||
datanode.EXPECT().SetAddress(mock.Anything).Return()
|
||||
datanode.EXPECT().SetMixCoordClient(mock.Anything).Return(nil)
|
||||
datanode.EXPECT().UpdateStateCode(mock.Anything).Return()
|
||||
datanode.EXPECT().Register().Return(nil)
|
||||
datanode.EXPECT().Init().Return(nil)
|
||||
@ -242,26 +206,9 @@ func TestIndexService(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, server)
|
||||
|
||||
mockRootCoord := mocks.NewMockMixCoordClient(t)
|
||||
mockRootCoord.EXPECT().GetComponentStates(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{
|
||||
State: &milvuspb.ComponentInfo{
|
||||
StateCode: commonpb.StateCode_Healthy,
|
||||
},
|
||||
Status: merr.Success(),
|
||||
SubcomponentStates: []*milvuspb.ComponentInfo{
|
||||
{
|
||||
StateCode: commonpb.StateCode_Healthy,
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
server.mixCoordClient = func() (types.MixCoordClient, error) {
|
||||
return mockRootCoord, nil
|
||||
}
|
||||
|
||||
dn := mocks.NewMockDataNode(t)
|
||||
dn.EXPECT().SetEtcdClient(mock.Anything).Return()
|
||||
dn.EXPECT().SetAddress(mock.Anything).Return()
|
||||
dn.EXPECT().SetMixCoordClient(mock.Anything).Return(nil)
|
||||
dn.EXPECT().UpdateStateCode(mock.Anything).Return()
|
||||
dn.EXPECT().Register().Return(nil)
|
||||
dn.EXPECT().Init().Return(nil)
|
||||
|
||||
@ -449,7 +449,6 @@ func (s *Server) init() error {
|
||||
return err
|
||||
}
|
||||
s.etcdCli = etcdCli
|
||||
s.proxy.SetEtcdClient(s.etcdCli)
|
||||
s.proxy.SetAddress(s.listenerManager.internalGrpcListener.Address())
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
@ -201,7 +201,6 @@ func Test_NewServer(t *testing.T) {
|
||||
mockProxy.EXPECT().Init().Return(nil)
|
||||
mockProxy.EXPECT().Start().Return(nil)
|
||||
mockProxy.EXPECT().Register().Return(nil)
|
||||
mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
|
||||
mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
|
||||
mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
|
||||
mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
|
||||
@ -687,7 +686,6 @@ func Test_NewServer(t *testing.T) {
|
||||
mockProxy.EXPECT().Init().Return(nil)
|
||||
mockProxy.EXPECT().Start().Return(nil)
|
||||
mockProxy.EXPECT().Register().Return(nil)
|
||||
mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
|
||||
mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
|
||||
mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
|
||||
mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
|
||||
@ -821,7 +819,6 @@ func Test_NewServer_HTTPServer_Enabled(t *testing.T) {
|
||||
mockProxy.EXPECT().Init().Return(nil)
|
||||
mockProxy.EXPECT().Start().Return(nil)
|
||||
mockProxy.EXPECT().Register().Return(nil)
|
||||
mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
|
||||
mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
|
||||
mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
|
||||
mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
|
||||
@ -886,7 +883,6 @@ func Test_NewServer_TLS_TwoWay(t *testing.T) {
|
||||
mockProxy.EXPECT().Init().Return(nil)
|
||||
mockProxy.EXPECT().Start().Return(nil)
|
||||
mockProxy.EXPECT().Register().Return(nil)
|
||||
mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
|
||||
mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
|
||||
mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
|
||||
mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
|
||||
@ -914,7 +910,6 @@ func Test_NewServer_TLS_OneWay(t *testing.T) {
|
||||
mockProxy.EXPECT().Init().Return(nil)
|
||||
mockProxy.EXPECT().Start().Return(nil)
|
||||
mockProxy.EXPECT().Register().Return(nil)
|
||||
mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
|
||||
mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
|
||||
mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
|
||||
mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
|
||||
@ -938,7 +933,6 @@ func Test_NewServer_TLS_FileNotExisted(t *testing.T) {
|
||||
|
||||
mockProxy := server.proxy.(*mocks.MockProxy)
|
||||
mockProxy.EXPECT().Stop().Return(nil)
|
||||
mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
|
||||
mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
|
||||
mockProxy.EXPECT().SetAddress(mock.Anything).Return()
|
||||
|
||||
@ -977,7 +971,6 @@ func Test_NewHTTPServer_TLS_TwoWay(t *testing.T) {
|
||||
mockProxy.EXPECT().Init().Return(nil)
|
||||
mockProxy.EXPECT().Start().Return(nil)
|
||||
mockProxy.EXPECT().Register().Return(nil)
|
||||
mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
|
||||
mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
|
||||
mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
|
||||
mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
|
||||
@ -1013,7 +1006,6 @@ func Test_NewHTTPServer_TLS_OneWay(t *testing.T) {
|
||||
mockProxy.EXPECT().Init().Return(nil)
|
||||
mockProxy.EXPECT().Start().Return(nil)
|
||||
mockProxy.EXPECT().Register().Return(nil)
|
||||
mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
|
||||
mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
|
||||
mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
|
||||
mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
|
||||
@ -1046,7 +1038,6 @@ func Test_NewHTTPServer_TLS_FileNotExisted(t *testing.T) {
|
||||
|
||||
mockProxy := server.proxy.(*mocks.MockProxy)
|
||||
mockProxy.EXPECT().Stop().Return(nil)
|
||||
mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return().Maybe()
|
||||
mockProxy.EXPECT().SetAddress(mock.Anything).Return().Maybe()
|
||||
Params := ¶mtable.Get().ProxyGrpcServerCfg
|
||||
|
||||
@ -1160,7 +1151,6 @@ func Test_Service_GracefulStop(t *testing.T) {
|
||||
mockProxy.EXPECT().Start().Return(nil)
|
||||
mockProxy.EXPECT().Stop().Return(nil)
|
||||
mockProxy.EXPECT().Register().Return(nil)
|
||||
mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
|
||||
mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
|
||||
mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
|
||||
mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
|
||||
|
||||
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
|
||||
kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
|
||||
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
|
||||
"github.com/milvus-io/milvus/pkg/v2/streaming/util/options"
|
||||
@ -18,8 +17,6 @@ var singleton WALAccesser = nil
|
||||
func Init() {
|
||||
c, _ := kvfactory.GetEtcdAndPath()
|
||||
singleton = newWALAccesser(c)
|
||||
// Add the wal accesser to the broadcaster registry for making broadcast operation.
|
||||
registry.Register(registry.AppendOperatorTypeStreaming, singleton)
|
||||
}
|
||||
|
||||
// Release releases the resources of the wal accesser.
|
||||
|
||||
@ -18,7 +18,20 @@
|
||||
|
||||
package streaming
|
||||
|
||||
import kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
|
||||
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
|
||||
)
|
||||
|
||||
var expectErr = make(chan error, 10)
|
||||
|
||||
// SetWALForTest initializes the singleton of wal for test.
|
||||
func SetWALForTest(w WALAccesser) {
|
||||
@ -29,3 +42,146 @@ func RecoverWALForTest() {
|
||||
c, _ := kvfactory.GetEtcdAndPath()
|
||||
singleton = newWALAccesser(c)
|
||||
}
|
||||
|
||||
func ExpectErrorOnce(err error) {
|
||||
expectErr <- err
|
||||
}
|
||||
|
||||
func SetupNoopWALForTest() {
|
||||
singleton = &noopWALAccesser{}
|
||||
}
|
||||
|
||||
type noopLocal struct{}
|
||||
|
||||
func (n *noopLocal) GetLatestMVCCTimestampIfLocal(ctx context.Context, vchannel string) (uint64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (n *noopLocal) GetMetricsIfLocal(ctx context.Context) (*types.StreamingNodeMetrics, error) {
|
||||
return &types.StreamingNodeMetrics{}, nil
|
||||
}
|
||||
|
||||
type noopBroadcast struct{}
|
||||
|
||||
func (n *noopBroadcast) Append(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) {
|
||||
if err := getExpectErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &types.BroadcastAppendResult{
|
||||
BroadcastID: 1,
|
||||
AppendResults: map[string]*types.AppendResult{
|
||||
"v1": {
|
||||
MessageID: rmq.NewRmqID(1),
|
||||
TimeTick: 10,
|
||||
Extra: &anypb.Any{},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (n *noopBroadcast) Ack(ctx context.Context, req types.BroadcastAckRequest) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type noopTxn struct{}
|
||||
|
||||
func (n *noopTxn) Append(ctx context.Context, msg message.MutableMessage, opts ...AppendOption) error {
|
||||
if err := getExpectErr(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *noopTxn) Commit(ctx context.Context) (*types.AppendResult, error) {
|
||||
if err := getExpectErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &types.AppendResult{}, nil
|
||||
}
|
||||
|
||||
func (n *noopTxn) Rollback(ctx context.Context) error {
|
||||
if err := getExpectErr(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type noopWALAccesser struct{}
|
||||
|
||||
func (n *noopWALAccesser) WALName() string {
|
||||
return "noop"
|
||||
}
|
||||
|
||||
func (n *noopWALAccesser) Local() Local {
|
||||
return &noopLocal{}
|
||||
}
|
||||
|
||||
func (n *noopWALAccesser) Txn(ctx context.Context, opts TxnOption) (Txn, error) {
|
||||
if err := getExpectErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &noopTxn{}, nil
|
||||
}
|
||||
|
||||
func (n *noopWALAccesser) RawAppend(ctx context.Context, msgs message.MutableMessage, opts ...AppendOption) (*types.AppendResult, error) {
|
||||
if err := getExpectErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
extra, _ := anypb.New(&messagespb.ManualFlushExtraResponse{
|
||||
SegmentIds: []int64{1, 2, 3},
|
||||
})
|
||||
return &types.AppendResult{
|
||||
MessageID: rmq.NewRmqID(1),
|
||||
TimeTick: 10,
|
||||
Extra: extra,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (n *noopWALAccesser) Broadcast() Broadcast {
|
||||
return &noopBroadcast{}
|
||||
}
|
||||
|
||||
func (n *noopWALAccesser) Read(ctx context.Context, opts ReadOption) Scanner {
|
||||
return &noopScanner{}
|
||||
}
|
||||
|
||||
func (n *noopWALAccesser) AppendMessages(ctx context.Context, msgs ...message.MutableMessage) AppendResponses {
|
||||
if err := getExpectErr(); err != nil {
|
||||
return AppendResponses{
|
||||
Responses: []AppendResponse{
|
||||
{
|
||||
AppendResult: nil,
|
||||
Error: err,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
return AppendResponses{}
|
||||
}
|
||||
|
||||
func (n *noopWALAccesser) AppendMessagesWithOption(ctx context.Context, opts AppendOption, msgs ...message.MutableMessage) AppendResponses {
|
||||
return AppendResponses{}
|
||||
}
|
||||
|
||||
type noopScanner struct{}
|
||||
|
||||
func (n *noopScanner) Done() <-chan struct{} {
|
||||
return make(chan struct{})
|
||||
}
|
||||
|
||||
func (n *noopScanner) Error() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *noopScanner) Close() {
|
||||
}
|
||||
|
||||
// getExpectErr is a helper function to get the error from the expectErr channel.
|
||||
func getExpectErr() error {
|
||||
select {
|
||||
case err := <-expectErr:
|
||||
return err
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -11,7 +11,6 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/distributed/streaming/internal/producer"
|
||||
"github.com/milvus-io/milvus/internal/streamingcoord/client"
|
||||
"github.com/milvus-io/milvus/internal/streamingnode/client/handler"
|
||||
"github.com/milvus-io/milvus/internal/util/streamingutil"
|
||||
"github.com/milvus-io/milvus/internal/util/streamingutil/status"
|
||||
"github.com/milvus-io/milvus/internal/util/streamingutil/util"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
@ -29,11 +28,7 @@ func newWALAccesser(c *clientv3.Client) *walAccesserImpl {
|
||||
// Create a new streaming coord client.
|
||||
streamingCoordClient := client.NewClient(c)
|
||||
// Create a new streamingnode handler client.
|
||||
var handlerClient handler.HandlerClient
|
||||
if streamingutil.IsStreamingServiceEnabled() {
|
||||
// streaming service is enabled, create the handler client for the streaming service.
|
||||
handlerClient = handler.NewHandlerClient(streamingCoordClient.Assignment())
|
||||
}
|
||||
handlerClient := handler.NewHandlerClient(streamingCoordClient.Assignment())
|
||||
w := &walAccesserImpl{
|
||||
lifetime: typeutil.NewLifetime(),
|
||||
streamingCoordClient: streamingCoordClient,
|
||||
|
||||
@ -4,27 +4,21 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/samber/lo"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
"github.com/milvus-io/milvus/internal/flushcommon/io"
|
||||
"github.com/milvus-io/milvus/internal/flushcommon/metacache"
|
||||
"github.com/milvus-io/milvus/internal/flushcommon/metacache/pkoracle"
|
||||
"github.com/milvus-io/milvus/internal/flushcommon/syncmgr"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/streamingutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/conc"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/retry"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
type l0WriteBuffer struct {
|
||||
@ -54,94 +48,6 @@ func NewL0WriteBuffer(channel string, metacache metacache.MetaCache, syncMgr syn
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (wb *l0WriteBuffer) dispatchDeleteMsgs(groups []*InsertData, deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) {
|
||||
batchSize := paramtable.Get().CommonCfg.BloomFilterApplyBatchSize.GetAsInt()
|
||||
split := func(pks []storage.PrimaryKey, pkTss []uint64, partitionSegments []*metacache.SegmentInfo, partitionGroups []*InsertData) []bool {
|
||||
lc := storage.NewBatchLocationsCache(pks)
|
||||
|
||||
// use hits to cache result
|
||||
hits := make([]bool, len(pks))
|
||||
|
||||
for _, segment := range partitionSegments {
|
||||
hits = segment.GetBloomFilterSet().BatchPkExistWithHits(lc, hits)
|
||||
}
|
||||
|
||||
for _, inData := range partitionGroups {
|
||||
hits = inData.batchPkExists(pks, pkTss, hits)
|
||||
}
|
||||
|
||||
return hits
|
||||
}
|
||||
|
||||
type BatchApplyRet = struct {
|
||||
// represent the idx for delete msg in deleteMsgs
|
||||
DeleteDataIdx int
|
||||
// represent the start idx for the batch in each deleteMsg
|
||||
StartIdx int
|
||||
Hits []bool
|
||||
}
|
||||
|
||||
// transform pk to primary key
|
||||
pksInDeleteMsgs := lo.Map(deleteMsgs, func(delMsg *msgstream.DeleteMsg, _ int) []storage.PrimaryKey {
|
||||
return storage.ParseIDs2PrimaryKeys(delMsg.GetPrimaryKeys())
|
||||
})
|
||||
|
||||
retIdx := 0
|
||||
retMap := typeutil.NewConcurrentMap[int, *BatchApplyRet]()
|
||||
pool := io.GetBFApplyPool()
|
||||
var futures []*conc.Future[any]
|
||||
for didx, delMsg := range deleteMsgs {
|
||||
pks := pksInDeleteMsgs[didx]
|
||||
pkTss := delMsg.GetTimestamps()
|
||||
partitionSegments := wb.metaCache.GetSegmentsBy(metacache.WithPartitionID(delMsg.PartitionID),
|
||||
metacache.WithSegmentState(commonpb.SegmentState_Growing, commonpb.SegmentState_Sealed, commonpb.SegmentState_Flushing, commonpb.SegmentState_Flushed))
|
||||
partitionGroups := lo.Filter(groups, func(inData *InsertData, _ int) bool {
|
||||
return delMsg.GetPartitionID() == common.AllPartitionsID || delMsg.GetPartitionID() == inData.partitionID
|
||||
})
|
||||
|
||||
for idx := 0; idx < len(pks); idx += batchSize {
|
||||
startIdx := idx
|
||||
endIdx := idx + batchSize
|
||||
if endIdx > len(pks) {
|
||||
endIdx = len(pks)
|
||||
}
|
||||
retIdx += 1
|
||||
tmpRetIdx := retIdx
|
||||
deleteDataId := didx
|
||||
future := pool.Submit(func() (any, error) {
|
||||
hits := split(pks[startIdx:endIdx], pkTss[startIdx:endIdx], partitionSegments, partitionGroups)
|
||||
retMap.Insert(tmpRetIdx, &BatchApplyRet{
|
||||
DeleteDataIdx: deleteDataId,
|
||||
StartIdx: startIdx,
|
||||
Hits: hits,
|
||||
})
|
||||
return nil, nil
|
||||
})
|
||||
futures = append(futures, future)
|
||||
}
|
||||
}
|
||||
conc.AwaitAll(futures...)
|
||||
|
||||
retMap.Range(func(key int, value *BatchApplyRet) bool {
|
||||
l0SegmentID := wb.getL0SegmentID(deleteMsgs[value.DeleteDataIdx].GetPartitionID(), startPos)
|
||||
pks := pksInDeleteMsgs[value.DeleteDataIdx]
|
||||
pkTss := deleteMsgs[value.DeleteDataIdx].GetTimestamps()
|
||||
|
||||
var deletePks []storage.PrimaryKey
|
||||
var deleteTss []typeutil.Timestamp
|
||||
for i, hit := range value.Hits {
|
||||
if hit {
|
||||
deletePks = append(deletePks, pks[value.StartIdx+i])
|
||||
deleteTss = append(deleteTss, pkTss[value.StartIdx+i])
|
||||
}
|
||||
}
|
||||
if len(deletePks) > 0 {
|
||||
wb.bufferDelete(l0SegmentID, deletePks, deleteTss, startPos, endPos)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (wb *l0WriteBuffer) dispatchDeleteMsgsWithoutFilter(deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) {
|
||||
for _, msg := range deleteMsgs {
|
||||
l0SegmentID := wb.getL0SegmentID(msg.GetPartitionID(), startPos)
|
||||
@ -165,31 +71,10 @@ func (wb *l0WriteBuffer) BufferData(insertData []*InsertData, deleteMsgs []*msgs
|
||||
}
|
||||
}
|
||||
|
||||
if paramtable.Get().DataNodeCfg.SkipBFStatsLoad.GetAsBool() || streamingutil.IsStreamingServiceEnabled() {
|
||||
// In streaming service mode, flushed segments no longer maintain a bloom filter.
|
||||
// So, here we skip generating BF (growing segment's BF will be regenerated during the sync phase)
|
||||
// and also skip filtering delete entries by bf.
|
||||
wb.dispatchDeleteMsgsWithoutFilter(deleteMsgs, startPos, endPos)
|
||||
} else {
|
||||
// distribute delete msg
|
||||
// bf write buffer check bloom filter of segment and current insert batch to decide which segment to write delete data
|
||||
wb.dispatchDeleteMsgs(insertData, deleteMsgs, startPos, endPos)
|
||||
|
||||
// update pk oracle
|
||||
for _, inData := range insertData {
|
||||
// segment shall always exists after buffer insert
|
||||
segments := wb.metaCache.GetSegmentsBy(metacache.WithSegmentIDs(inData.segmentID))
|
||||
for _, segment := range segments {
|
||||
for _, fieldData := range inData.pkField {
|
||||
err := segment.GetBloomFilterSet().UpdatePKRange(fieldData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// In streaming service mode, flushed segments no longer maintain a bloom filter.
|
||||
// So, here we skip generating BF (growing segment's BF will be regenerated during the sync phase)
|
||||
// and also skip filtering delete entries by bf.
|
||||
wb.dispatchDeleteMsgsWithoutFilter(deleteMsgs, startPos, endPos)
|
||||
// update buffer last checkpoint
|
||||
wb.checkpoint = endPos
|
||||
|
||||
|
||||
@ -18,7 +18,6 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/flushcommon/metacache/pkoracle"
|
||||
"github.com/milvus-io/milvus/internal/flushcommon/syncmgr"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/streamingutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
@ -324,10 +323,8 @@ func (wb *writeBufferBase) syncSegments(ctx context.Context, segmentIDs []int64)
|
||||
}
|
||||
|
||||
if syncTask.IsFlush() {
|
||||
if paramtable.Get().DataNodeCfg.SkipBFStatsLoad.GetAsBool() || streamingutil.IsStreamingServiceEnabled() {
|
||||
wb.metaCache.RemoveSegments(metacache.WithSegmentIDs(syncTask.SegmentID()))
|
||||
log.Info("flushed segment removed", zap.Int64("segmentID", syncTask.SegmentID()), zap.String("channel", syncTask.ChannelName()))
|
||||
}
|
||||
wb.metaCache.RemoveSegments(metacache.WithSegmentIDs(syncTask.SegmentID()))
|
||||
log.Info("flushed segment removed", zap.Int64("segmentID", syncTask.SegmentID()), zap.String("channel", syncTask.ChannelName()))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
156
internal/mocks/distributed/mock_streaming/mock_Scanner.go
Normal file
156
internal/mocks/distributed/mock_streaming/mock_Scanner.go
Normal file
@ -0,0 +1,156 @@
|
||||
// Code generated by mockery v2.53.3. DO NOT EDIT.
|
||||
|
||||
package mock_streaming
|
||||
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
|
||||
// MockScanner is an autogenerated mock type for the Scanner type
|
||||
type MockScanner struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type MockScanner_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *MockScanner) EXPECT() *MockScanner_Expecter {
|
||||
return &MockScanner_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Close provides a mock function with no fields
|
||||
func (_m *MockScanner) Close() {
|
||||
_m.Called()
|
||||
}
|
||||
|
||||
// MockScanner_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
|
||||
type MockScanner_Close_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Close is a helper method to define mock.On call
|
||||
func (_e *MockScanner_Expecter) Close() *MockScanner_Close_Call {
|
||||
return &MockScanner_Close_Call{Call: _e.mock.On("Close")}
|
||||
}
|
||||
|
||||
func (_c *MockScanner_Close_Call) Run(run func()) *MockScanner_Close_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockScanner_Close_Call) Return() *MockScanner_Close_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockScanner_Close_Call) RunAndReturn(run func()) *MockScanner_Close_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Done provides a mock function with no fields
|
||||
func (_m *MockScanner) Done() <-chan struct{} {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Done")
|
||||
}
|
||||
|
||||
var r0 <-chan struct{}
|
||||
if rf, ok := ret.Get(0).(func() <-chan struct{}); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(<-chan struct{})
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockScanner_Done_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Done'
|
||||
type MockScanner_Done_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Done is a helper method to define mock.On call
|
||||
func (_e *MockScanner_Expecter) Done() *MockScanner_Done_Call {
|
||||
return &MockScanner_Done_Call{Call: _e.mock.On("Done")}
|
||||
}
|
||||
|
||||
func (_c *MockScanner_Done_Call) Run(run func()) *MockScanner_Done_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockScanner_Done_Call) Return(_a0 <-chan struct{}) *MockScanner_Done_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockScanner_Done_Call) RunAndReturn(run func() <-chan struct{}) *MockScanner_Done_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Error provides a mock function with no fields
|
||||
func (_m *MockScanner) Error() error {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Error")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockScanner_Error_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Error'
|
||||
type MockScanner_Error_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Error is a helper method to define mock.On call
|
||||
func (_e *MockScanner_Expecter) Error() *MockScanner_Error_Call {
|
||||
return &MockScanner_Error_Call{Call: _e.mock.On("Error")}
|
||||
}
|
||||
|
||||
func (_c *MockScanner_Error_Call) Run(run func()) *MockScanner_Error_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockScanner_Error_Call) Return(_a0 error) *MockScanner_Error_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockScanner_Error_Call) RunAndReturn(run func() error) *MockScanner_Error_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewMockScanner creates a new instance of MockScanner. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewMockScanner(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *MockScanner {
|
||||
mock := &MockScanner{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@ -16,8 +16,6 @@ import (
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
|
||||
types "github.com/milvus-io/milvus/internal/types"
|
||||
|
||||
workerpb "github.com/milvus-io/milvus/pkg/v2/proto/workerpb"
|
||||
)
|
||||
|
||||
@ -1918,52 +1916,6 @@ func (_c *MockDataNode_SetEtcdClient_Call) RunAndReturn(run func(*clientv3.Clien
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetMixCoordClient provides a mock function with given fields: mixCoord
|
||||
func (_m *MockDataNode) SetMixCoordClient(mixCoord types.MixCoordClient) error {
|
||||
ret := _m.Called(mixCoord)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for SetMixCoordClient")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(types.MixCoordClient) error); ok {
|
||||
r0 = rf(mixCoord)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockDataNode_SetMixCoordClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetMixCoordClient'
|
||||
type MockDataNode_SetMixCoordClient_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SetMixCoordClient is a helper method to define mock.On call
|
||||
// - mixCoord types.MixCoordClient
|
||||
func (_e *MockDataNode_Expecter) SetMixCoordClient(mixCoord interface{}) *MockDataNode_SetMixCoordClient_Call {
|
||||
return &MockDataNode_SetMixCoordClient_Call{Call: _e.mock.On("SetMixCoordClient", mixCoord)}
|
||||
}
|
||||
|
||||
func (_c *MockDataNode_SetMixCoordClient_Call) Run(run func(mixCoord types.MixCoordClient)) *MockDataNode_SetMixCoordClient_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(types.MixCoordClient))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockDataNode_SetMixCoordClient_Call) Return(_a0 error) *MockDataNode_SetMixCoordClient_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockDataNode_SetMixCoordClient_Call) RunAndReturn(run func(types.MixCoordClient) error) *MockDataNode_SetMixCoordClient_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// ShowConfigurations provides a mock function with given fields: _a0, _a1
|
||||
func (_m *MockDataNode) ShowConfigurations(_a0 context.Context, _a1 *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) {
|
||||
ret := _m.Called(_a0, _a1)
|
||||
|
||||
@ -6,7 +6,6 @@ import (
|
||||
context "context"
|
||||
|
||||
commonpb "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
|
||||
federpb "github.com/milvus-io/milvus-proto/go-api/v2/federpb"
|
||||
|
||||
@ -6690,39 +6689,6 @@ func (_c *MockProxy_SetAddress_Call) RunAndReturn(run func(string)) *MockProxy_S
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetEtcdClient provides a mock function with given fields: etcdClient
|
||||
func (_m *MockProxy) SetEtcdClient(etcdClient *clientv3.Client) {
|
||||
_m.Called(etcdClient)
|
||||
}
|
||||
|
||||
// MockProxy_SetEtcdClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetEtcdClient'
|
||||
type MockProxy_SetEtcdClient_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SetEtcdClient is a helper method to define mock.On call
|
||||
// - etcdClient *clientv3.Client
|
||||
func (_e *MockProxy_Expecter) SetEtcdClient(etcdClient interface{}) *MockProxy_SetEtcdClient_Call {
|
||||
return &MockProxy_SetEtcdClient_Call{Call: _e.mock.On("SetEtcdClient", etcdClient)}
|
||||
}
|
||||
|
||||
func (_c *MockProxy_SetEtcdClient_Call) Run(run func(etcdClient *clientv3.Client)) *MockProxy_SetEtcdClient_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(*clientv3.Client))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockProxy_SetEtcdClient_Call) Return() *MockProxy_SetEtcdClient_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockProxy_SetEtcdClient_Call) RunAndReturn(run func(*clientv3.Client)) *MockProxy_SetEtcdClient_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetMixCoordClient provides a mock function with given fields: rootCoord
|
||||
func (_m *MockProxy) SetMixCoordClient(rootCoord types.MixCoordClient) {
|
||||
_m.Called(rootCoord)
|
||||
|
||||
@ -39,9 +39,7 @@ import (
|
||||
type channelsMgr interface {
|
||||
getChannels(collectionID UniqueID) ([]pChan, error)
|
||||
getVChannels(collectionID UniqueID) ([]vChan, error)
|
||||
getOrCreateDmlStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error)
|
||||
removeDMLStream(collectionID UniqueID)
|
||||
removeAllDMLStream()
|
||||
}
|
||||
|
||||
type channelInfos struct {
|
||||
@ -52,7 +50,6 @@ type channelInfos struct {
|
||||
|
||||
type streamInfos struct {
|
||||
channelInfos channelInfos
|
||||
stream msgstream.MsgStream
|
||||
}
|
||||
|
||||
func removeDuplicate(ss []string) []string {
|
||||
@ -114,9 +111,8 @@ type singleTypeChannelsMgr struct {
|
||||
infos map[UniqueID]streamInfos // collection id -> stream infos
|
||||
mu sync.RWMutex
|
||||
|
||||
getChannelsFunc getChannelsFuncType
|
||||
repackFunc repackFuncType
|
||||
msgStreamFactory msgstream.Factory
|
||||
getChannelsFunc getChannelsFuncType
|
||||
repackFunc repackFuncType
|
||||
}
|
||||
|
||||
func (mgr *singleTypeChannelsMgr) getAllChannels(collectionID UniqueID) (channelInfos, error) {
|
||||
@ -167,27 +163,6 @@ func (mgr *singleTypeChannelsMgr) getVChannels(collectionID UniqueID) ([]vChan,
|
||||
return channelInfos.vchans, nil
|
||||
}
|
||||
|
||||
func (mgr *singleTypeChannelsMgr) streamExistPrivate(collectionID UniqueID) bool {
|
||||
streamInfos, ok := mgr.infos[collectionID]
|
||||
return ok && streamInfos.stream != nil
|
||||
}
|
||||
|
||||
func createStream(ctx context.Context, factory msgstream.Factory, pchans []pChan, repack repackFuncType) (msgstream.MsgStream, error) {
|
||||
var stream msgstream.MsgStream
|
||||
var err error
|
||||
|
||||
stream, err = factory.NewMsgStream(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stream.AsProducer(ctx, pchans)
|
||||
if repack != nil {
|
||||
stream.SetRepackFunc(repack)
|
||||
}
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func incPChansMetrics(pchans []pChan) {
|
||||
for _, pc := range pchans {
|
||||
metrics.ProxyMsgStreamObjectsForPChan.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), pc).Inc()
|
||||
@ -200,67 +175,6 @@ func decPChanMetrics(pchans []pChan) {
|
||||
}
|
||||
}
|
||||
|
||||
// createMsgStream create message stream for specified collection. Idempotent.
|
||||
// If stream already exists, directly return it and no error will be returned.
|
||||
func (mgr *singleTypeChannelsMgr) createMsgStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
|
||||
mgr.mu.RLock()
|
||||
infos, ok := mgr.infos[collectionID]
|
||||
if ok && infos.stream != nil {
|
||||
// already exist.
|
||||
mgr.mu.RUnlock()
|
||||
return infos.stream, nil
|
||||
}
|
||||
mgr.mu.RUnlock()
|
||||
|
||||
channelInfos, err := mgr.getChannelsFunc(collectionID)
|
||||
if err != nil {
|
||||
// What if stream created by other goroutines?
|
||||
log.Error("failed to get channels", zap.Error(err), zap.Int64("collection", collectionID))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stream, err := createStream(ctx, mgr.msgStreamFactory, channelInfos.pchans, mgr.repackFunc)
|
||||
if err != nil {
|
||||
// What if stream created by other goroutines?
|
||||
log.Error("failed to create message stream", zap.Error(err), zap.Int64("collection", collectionID))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mgr.mu.Lock()
|
||||
defer mgr.mu.Unlock()
|
||||
if !mgr.streamExistPrivate(collectionID) {
|
||||
log.Info("create message stream", zap.Int64("collection", collectionID),
|
||||
zap.Strings("virtual_channels", channelInfos.vchans),
|
||||
zap.Strings("physical_channels", channelInfos.pchans))
|
||||
mgr.infos[collectionID] = streamInfos{channelInfos: channelInfos, stream: stream}
|
||||
incPChansMetrics(channelInfos.pchans)
|
||||
} else {
|
||||
stream.Close()
|
||||
}
|
||||
|
||||
return mgr.infos[collectionID].stream, nil
|
||||
}
|
||||
|
||||
func (mgr *singleTypeChannelsMgr) lockGetStream(collectionID UniqueID) (msgstream.MsgStream, error) {
|
||||
mgr.mu.RLock()
|
||||
defer mgr.mu.RUnlock()
|
||||
streamInfos, ok := mgr.infos[collectionID]
|
||||
if ok {
|
||||
return streamInfos.stream, nil
|
||||
}
|
||||
return nil, fmt.Errorf("collection not found: %d", collectionID)
|
||||
}
|
||||
|
||||
// getOrCreateStream get message stream of specified collection.
|
||||
// If stream doesn't exist, call createMsgStream to create for it.
|
||||
func (mgr *singleTypeChannelsMgr) getOrCreateStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
|
||||
if stream, err := mgr.lockGetStream(collectionID); err == nil {
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
return mgr.createMsgStream(ctx, collectionID)
|
||||
}
|
||||
|
||||
// removeStream remove the corresponding stream of the specified collection. Idempotent.
|
||||
// If stream already exists, remove it, otherwise do nothing.
|
||||
func (mgr *singleTypeChannelsMgr) removeStream(collectionID UniqueID) {
|
||||
@ -268,34 +182,19 @@ func (mgr *singleTypeChannelsMgr) removeStream(collectionID UniqueID) {
|
||||
defer mgr.mu.Unlock()
|
||||
if info, ok := mgr.infos[collectionID]; ok {
|
||||
decPChanMetrics(info.channelInfos.pchans)
|
||||
info.stream.Close()
|
||||
delete(mgr.infos, collectionID)
|
||||
}
|
||||
log.Info("dml stream removed", zap.Int64("collection_id", collectionID))
|
||||
}
|
||||
|
||||
// removeAllStream remove all message stream.
|
||||
func (mgr *singleTypeChannelsMgr) removeAllStream() {
|
||||
mgr.mu.Lock()
|
||||
defer mgr.mu.Unlock()
|
||||
for _, info := range mgr.infos {
|
||||
info.stream.Close()
|
||||
decPChanMetrics(info.channelInfos.pchans)
|
||||
}
|
||||
mgr.infos = make(map[UniqueID]streamInfos)
|
||||
log.Info("all dml stream removed")
|
||||
}
|
||||
|
||||
func newSingleTypeChannelsMgr(
|
||||
getChannelsFunc getChannelsFuncType,
|
||||
msgStreamFactory msgstream.Factory,
|
||||
repackFunc repackFuncType,
|
||||
) *singleTypeChannelsMgr {
|
||||
return &singleTypeChannelsMgr{
|
||||
infos: make(map[UniqueID]streamInfos),
|
||||
getChannelsFunc: getChannelsFunc,
|
||||
repackFunc: repackFunc,
|
||||
msgStreamFactory: msgStreamFactory,
|
||||
infos: make(map[UniqueID]streamInfos),
|
||||
getChannelsFunc: getChannelsFunc,
|
||||
repackFunc: repackFunc,
|
||||
}
|
||||
}
|
||||
|
||||
@ -315,25 +214,16 @@ func (mgr *channelsMgrImpl) getVChannels(collectionID UniqueID) ([]vChan, error)
|
||||
return mgr.dmlChannelsMgr.getVChannels(collectionID)
|
||||
}
|
||||
|
||||
func (mgr *channelsMgrImpl) getOrCreateDmlStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
|
||||
return mgr.dmlChannelsMgr.getOrCreateStream(ctx, collectionID)
|
||||
}
|
||||
|
||||
func (mgr *channelsMgrImpl) removeDMLStream(collectionID UniqueID) {
|
||||
mgr.dmlChannelsMgr.removeStream(collectionID)
|
||||
}
|
||||
|
||||
func (mgr *channelsMgrImpl) removeAllDMLStream() {
|
||||
mgr.dmlChannelsMgr.removeAllStream()
|
||||
}
|
||||
|
||||
// newChannelsMgrImpl constructs a channels manager.
|
||||
func newChannelsMgrImpl(
|
||||
getDmlChannelsFunc getChannelsFuncType,
|
||||
dmlRepackFunc repackFuncType,
|
||||
msgStreamFactory msgstream.Factory,
|
||||
) *channelsMgrImpl {
|
||||
return &channelsMgrImpl{
|
||||
dmlChannelsMgr: newSingleTypeChannelsMgr(getDmlChannelsFunc, msgStreamFactory, dmlRepackFunc),
|
||||
dmlChannelsMgr: newSingleTypeChannelsMgr(getDmlChannelsFunc, dmlRepackFunc),
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,7 +18,6 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
@ -28,8 +27,6 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
)
|
||||
|
||||
func Test_removeDuplicate(t *testing.T) {
|
||||
@ -205,220 +202,11 @@ func Test_singleTypeChannelsMgr_getVChannels(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func Test_createStream(t *testing.T) {
|
||||
t.Run("failed to create msgstream", func(t *testing.T) {
|
||||
factory := newMockMsgStreamFactory()
|
||||
factory.fQStream = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
return nil, errors.New("mock")
|
||||
}
|
||||
_, err := createStream(context.TODO(), factory, nil, nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("failed to create query msgstream", func(t *testing.T) {
|
||||
factory := newMockMsgStreamFactory()
|
||||
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
return nil, errors.New("mock")
|
||||
}
|
||||
_, err := createStream(context.TODO(), factory, nil, nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
factory := newMockMsgStreamFactory()
|
||||
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
return newMockMsgStream(), nil
|
||||
}
|
||||
_, err := createStream(context.TODO(), factory, []string{"111"}, func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
|
||||
return nil, nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
|
||||
paramtable.Init()
|
||||
t.Run("re-create", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: map[UniqueID]streamInfos{
|
||||
100: {stream: newMockMsgStream()},
|
||||
},
|
||||
}
|
||||
stream, err := m.createMsgStream(context.TODO(), 100)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, stream)
|
||||
})
|
||||
|
||||
t.Run("concurrent create", func(t *testing.T) {
|
||||
factory := newMockMsgStreamFactory()
|
||||
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
return newMockMsgStream(), nil
|
||||
}
|
||||
stopCh := make(chan struct{})
|
||||
readyCh := make(chan struct{})
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: make(map[UniqueID]streamInfos),
|
||||
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
|
||||
close(readyCh)
|
||||
<-stopCh
|
||||
return channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}, nil
|
||||
},
|
||||
msgStreamFactory: factory,
|
||||
repackFunc: nil,
|
||||
}
|
||||
|
||||
firstStream := streamInfos{stream: newMockMsgStream()}
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
stream, err := m.createMsgStream(context.TODO(), 100)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, stream)
|
||||
}()
|
||||
// make sure create msg stream has run at getchannels
|
||||
<-readyCh
|
||||
// mock create stream for same collection in same time.
|
||||
m.mu.Lock()
|
||||
m.infos[100] = firstStream
|
||||
m.mu.Unlock()
|
||||
|
||||
close(stopCh)
|
||||
wg.Wait()
|
||||
})
|
||||
t.Run("failed to get channels", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
|
||||
return channelInfos{}, errors.New("mock")
|
||||
},
|
||||
}
|
||||
_, err := m.createMsgStream(context.TODO(), 100)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("failed to create message stream", func(t *testing.T) {
|
||||
factory := newMockMsgStreamFactory()
|
||||
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
return nil, errors.New("mock")
|
||||
}
|
||||
m := &singleTypeChannelsMgr{
|
||||
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
|
||||
return channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}, nil
|
||||
},
|
||||
msgStreamFactory: factory,
|
||||
repackFunc: nil,
|
||||
}
|
||||
_, err := m.createMsgStream(context.TODO(), 100)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
factory := newMockMsgStreamFactory()
|
||||
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
return newMockMsgStream(), nil
|
||||
}
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: make(map[UniqueID]streamInfos),
|
||||
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
|
||||
return channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}, nil
|
||||
},
|
||||
msgStreamFactory: factory,
|
||||
repackFunc: nil,
|
||||
}
|
||||
stream, err := m.createMsgStream(context.TODO(), 100)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, stream)
|
||||
stream, err = m.getOrCreateStream(context.TODO(), 100)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, stream)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_singleTypeChannelsMgr_lockGetStream(t *testing.T) {
|
||||
t.Run("collection not found", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: make(map[UniqueID]streamInfos),
|
||||
}
|
||||
_, err := m.lockGetStream(100)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: map[UniqueID]streamInfos{
|
||||
100: {stream: newMockMsgStream()},
|
||||
},
|
||||
}
|
||||
stream, err := m.lockGetStream(100)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, stream)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_singleTypeChannelsMgr_getStream(t *testing.T) {
|
||||
t.Run("exist", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: map[UniqueID]streamInfos{
|
||||
100: {stream: newMockMsgStream()},
|
||||
},
|
||||
}
|
||||
stream, err := m.getOrCreateStream(context.TODO(), 100)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, stream)
|
||||
})
|
||||
|
||||
t.Run("failed to create", func(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: map[UniqueID]streamInfos{},
|
||||
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
|
||||
return channelInfos{}, errors.New("mock")
|
||||
},
|
||||
}
|
||||
_, err := m.getOrCreateStream(context.TODO(), 100)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("get after create", func(t *testing.T) {
|
||||
factory := newMockMsgStreamFactory()
|
||||
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
return newMockMsgStream(), nil
|
||||
}
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: make(map[UniqueID]streamInfos),
|
||||
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
|
||||
return channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}, nil
|
||||
},
|
||||
msgStreamFactory: factory,
|
||||
repackFunc: nil,
|
||||
}
|
||||
stream, err := m.getOrCreateStream(context.TODO(), 100)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, stream)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_singleTypeChannelsMgr_removeStream(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: map[UniqueID]streamInfos{
|
||||
100: {
|
||||
stream: newMockMsgStream(),
|
||||
},
|
||||
100: {},
|
||||
},
|
||||
}
|
||||
m.removeStream(100)
|
||||
_, err := m.lockGetStream(100)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func Test_singleTypeChannelsMgr_removeAllStream(t *testing.T) {
|
||||
m := &singleTypeChannelsMgr{
|
||||
infos: map[UniqueID]streamInfos{
|
||||
100: {
|
||||
stream: newMockMsgStream(),
|
||||
},
|
||||
},
|
||||
}
|
||||
m.removeAllStream()
|
||||
_, err := m.lockGetStream(100)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
@ -18,7 +18,6 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
@ -46,7 +45,6 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/ctokenizer"
|
||||
"github.com/milvus-io/milvus/internal/util/hookutil"
|
||||
"github.com/milvus-io/milvus/internal/util/streamingutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
||||
@ -255,7 +253,6 @@ func (node *Proxy) CreateDatabase(ctx context.Context, request *milvuspb.CreateD
|
||||
Condition: NewTaskCondition(ctx),
|
||||
CreateDatabaseRequest: request,
|
||||
mixCoord: node.mixCoord,
|
||||
replicateMsgStream: node.replicateMsgStream,
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With(
|
||||
@ -323,7 +320,6 @@ func (node *Proxy) DropDatabase(ctx context.Context, request *milvuspb.DropDatab
|
||||
Condition: NewTaskCondition(ctx),
|
||||
DropDatabaseRequest: request,
|
||||
mixCoord: node.mixCoord,
|
||||
replicateMsgStream: node.replicateMsgStream,
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With(
|
||||
@ -452,7 +448,6 @@ func (node *Proxy) AlterDatabase(ctx context.Context, request *milvuspb.AlterDat
|
||||
Condition: NewTaskCondition(ctx),
|
||||
AlterDatabaseRequest: request,
|
||||
mixCoord: node.mixCoord,
|
||||
replicateMsgStream: node.replicateMsgStream,
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With(
|
||||
@ -667,7 +662,6 @@ func (node *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCol
|
||||
DropCollectionRequest: request,
|
||||
mixCoord: node.mixCoord,
|
||||
chMgr: node.chMgr,
|
||||
chTicker: node.chTicker,
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With(
|
||||
@ -833,7 +827,6 @@ func (node *Proxy) LoadCollection(ctx context.Context, request *milvuspb.LoadCol
|
||||
Condition: NewTaskCondition(ctx),
|
||||
LoadCollectionRequest: request,
|
||||
mixCoord: node.mixCoord,
|
||||
replicateMsgStream: node.replicateMsgStream,
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With(
|
||||
@ -908,7 +901,6 @@ func (node *Proxy) ReleaseCollection(ctx context.Context, request *milvuspb.Rele
|
||||
Condition: NewTaskCondition(ctx),
|
||||
ReleaseCollectionRequest: request,
|
||||
mixCoord: node.mixCoord,
|
||||
replicateMsgStream: node.replicateMsgStream,
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With(
|
||||
@ -1278,7 +1270,6 @@ func (node *Proxy) AlterCollection(ctx context.Context, request *milvuspb.AlterC
|
||||
Condition: NewTaskCondition(ctx),
|
||||
AlterCollectionRequest: request,
|
||||
mixCoord: node.mixCoord,
|
||||
replicateMsgStream: node.replicateMsgStream,
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With(
|
||||
@ -1343,7 +1334,6 @@ func (node *Proxy) AlterCollectionField(ctx context.Context, request *milvuspb.A
|
||||
Condition: NewTaskCondition(ctx),
|
||||
AlterCollectionFieldRequest: request,
|
||||
mixCoord: node.mixCoord,
|
||||
replicateMsgStream: node.replicateMsgStream,
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With(
|
||||
@ -1616,7 +1606,6 @@ func (node *Proxy) LoadPartitions(ctx context.Context, request *milvuspb.LoadPar
|
||||
Condition: NewTaskCondition(ctx),
|
||||
LoadPartitionsRequest: request,
|
||||
mixCoord: node.mixCoord,
|
||||
replicateMsgStream: node.replicateMsgStream,
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With(
|
||||
@ -1682,7 +1671,6 @@ func (node *Proxy) ReleasePartitions(ctx context.Context, request *milvuspb.Rele
|
||||
Condition: NewTaskCondition(ctx),
|
||||
ReleasePartitionsRequest: request,
|
||||
mixCoord: node.mixCoord,
|
||||
replicateMsgStream: node.replicateMsgStream,
|
||||
}
|
||||
|
||||
method := "ReleasePartitions"
|
||||
@ -2082,11 +2070,10 @@ func (node *Proxy) CreateIndex(ctx context.Context, request *milvuspb.CreateInde
|
||||
defer sp.End()
|
||||
|
||||
cit := &createIndexTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
req: request,
|
||||
mixCoord: node.mixCoord,
|
||||
replicateMsgStream: node.replicateMsgStream,
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
req: request,
|
||||
mixCoord: node.mixCoord,
|
||||
}
|
||||
|
||||
method := "CreateIndex"
|
||||
@ -2152,11 +2139,10 @@ func (node *Proxy) AlterIndex(ctx context.Context, request *milvuspb.AlterIndexR
|
||||
defer sp.End()
|
||||
|
||||
task := &alterIndexTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
req: request,
|
||||
mixCoord: node.mixCoord,
|
||||
replicateMsgStream: node.replicateMsgStream,
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
req: request,
|
||||
mixCoord: node.mixCoord,
|
||||
}
|
||||
|
||||
method := "AlterIndex"
|
||||
@ -2370,11 +2356,10 @@ func (node *Proxy) DropIndex(ctx context.Context, request *milvuspb.DropIndexReq
|
||||
defer sp.End()
|
||||
|
||||
dit := &dropIndexTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
DropIndexRequest: request,
|
||||
mixCoord: node.mixCoord,
|
||||
replicateMsgStream: node.replicateMsgStream,
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
DropIndexRequest: request,
|
||||
mixCoord: node.mixCoord,
|
||||
}
|
||||
|
||||
method := "DropIndex"
|
||||
@ -2630,17 +2615,9 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
|
||||
},
|
||||
},
|
||||
idAllocator: node.rowIDAllocator,
|
||||
segIDAssigner: node.segAssigner,
|
||||
chMgr: node.chMgr,
|
||||
chTicker: node.chTicker,
|
||||
schemaTimestamp: request.SchemaTimestamp,
|
||||
}
|
||||
var enqueuedTask task = it
|
||||
if streamingutil.IsStreamingServiceEnabled() {
|
||||
enqueuedTask = &insertTaskByStreamingService{
|
||||
insertTask: it,
|
||||
}
|
||||
}
|
||||
|
||||
constructFailedResponse := func(err error) *milvuspb.MutationResult {
|
||||
numRows := request.NumRows
|
||||
@ -2657,7 +2634,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
|
||||
|
||||
log.Debug("Enqueue insert request in Proxy")
|
||||
|
||||
if err := node.sched.dmQueue.Enqueue(enqueuedTask); err != nil {
|
||||
if err := node.sched.dmQueue.Enqueue(it); err != nil {
|
||||
log.Warn("Failed to enqueue insert task: " + err.Error())
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
|
||||
metrics.AbandonLabel, request.GetDbName(), request.GetCollectionName()).Inc()
|
||||
@ -2765,7 +2742,6 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
|
||||
idAllocator: node.rowIDAllocator,
|
||||
tsoAllocatorIns: node.tsoAllocator,
|
||||
chMgr: node.chMgr,
|
||||
chTicker: node.chTicker,
|
||||
queue: node.sched.dmQueue,
|
||||
lb: node.lbPolicy,
|
||||
limiter: limiter,
|
||||
@ -2874,23 +2850,15 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest)
|
||||
},
|
||||
|
||||
idAllocator: node.rowIDAllocator,
|
||||
segIDAssigner: node.segAssigner,
|
||||
chMgr: node.chMgr,
|
||||
chTicker: node.chTicker,
|
||||
schemaTimestamp: request.SchemaTimestamp,
|
||||
}
|
||||
var enqueuedTask task = it
|
||||
if streamingutil.IsStreamingServiceEnabled() {
|
||||
enqueuedTask = &upsertTaskByStreamingService{
|
||||
upsertTask: it,
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("Enqueue upsert request in Proxy",
|
||||
zap.Int("len(FieldsData)", len(request.FieldsData)),
|
||||
zap.Int("len(HashKeys)", len(request.HashKeys)))
|
||||
|
||||
if err := node.sched.dmQueue.Enqueue(enqueuedTask); err != nil {
|
||||
if err := node.sched.dmQueue.Enqueue(it); err != nil {
|
||||
log.Info("Failed to enqueue upsert task",
|
||||
zap.Error(err))
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
|
||||
@ -3561,11 +3529,11 @@ func (node *Proxy) Flush(ctx context.Context, request *milvuspb.FlushRequest) (*
|
||||
defer sp.End()
|
||||
|
||||
ft := &flushTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
FlushRequest: request,
|
||||
mixCoord: node.mixCoord,
|
||||
replicateMsgStream: node.replicateMsgStream,
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
FlushRequest: request,
|
||||
mixCoord: node.mixCoord,
|
||||
chMgr: node.chMgr,
|
||||
}
|
||||
|
||||
method := "Flush"
|
||||
@ -3578,16 +3546,7 @@ func (node *Proxy) Flush(ctx context.Context, request *milvuspb.FlushRequest) (*
|
||||
zap.Any("collections", request.CollectionNames))
|
||||
|
||||
log.Debug(rpcReceived(method))
|
||||
|
||||
var enqueuedTask task = ft
|
||||
if streamingutil.IsStreamingServiceEnabled() {
|
||||
enqueuedTask = &flushTaskByStreamingService{
|
||||
flushTask: ft,
|
||||
chMgr: node.chMgr,
|
||||
}
|
||||
}
|
||||
|
||||
if err := node.sched.dcQueue.Enqueue(enqueuedTask); err != nil {
|
||||
if err := node.sched.dcQueue.Enqueue(ft); err != nil {
|
||||
log.Warn(
|
||||
rpcFailedToEnqueue(method),
|
||||
zap.Error(err))
|
||||
@ -3846,7 +3805,6 @@ func (node *Proxy) CreateAlias(ctx context.Context, request *milvuspb.CreateAlia
|
||||
Condition: NewTaskCondition(ctx),
|
||||
CreateAliasRequest: request,
|
||||
mixCoord: node.mixCoord,
|
||||
replicateMsgStream: node.replicateMsgStream,
|
||||
}
|
||||
|
||||
method := "CreateAlias"
|
||||
@ -4034,11 +3992,10 @@ func (node *Proxy) DropAlias(ctx context.Context, request *milvuspb.DropAliasReq
|
||||
defer sp.End()
|
||||
|
||||
dat := &DropAliasTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
DropAliasRequest: request,
|
||||
mixCoord: node.mixCoord,
|
||||
replicateMsgStream: node.replicateMsgStream,
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
DropAliasRequest: request,
|
||||
mixCoord: node.mixCoord,
|
||||
}
|
||||
|
||||
method := "DropAlias"
|
||||
@ -4098,11 +4055,10 @@ func (node *Proxy) AlterAlias(ctx context.Context, request *milvuspb.AlterAliasR
|
||||
defer sp.End()
|
||||
|
||||
aat := &AlterAliasTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
AlterAliasRequest: request,
|
||||
mixCoord: node.mixCoord,
|
||||
replicateMsgStream: node.replicateMsgStream,
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
AlterAliasRequest: request,
|
||||
mixCoord: node.mixCoord,
|
||||
}
|
||||
|
||||
method := "AlterAlias"
|
||||
@ -4174,11 +4130,11 @@ func (node *Proxy) FlushAll(ctx context.Context, request *milvuspb.FlushAllReque
|
||||
defer sp.End()
|
||||
|
||||
ft := &flushAllTask{
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
FlushAllRequest: request,
|
||||
mixCoord: node.mixCoord,
|
||||
replicateMsgStream: node.replicateMsgStream,
|
||||
ctx: ctx,
|
||||
Condition: NewTaskCondition(ctx),
|
||||
FlushAllRequest: request,
|
||||
mixCoord: node.mixCoord,
|
||||
chMgr: node.chMgr,
|
||||
}
|
||||
|
||||
method := "FlushAll"
|
||||
@ -4191,15 +4147,7 @@ func (node *Proxy) FlushAll(ctx context.Context, request *milvuspb.FlushAllReque
|
||||
|
||||
log.Debug(rpcReceived(method))
|
||||
|
||||
var enqueuedTask task = ft
|
||||
if streamingutil.IsStreamingServiceEnabled() {
|
||||
enqueuedTask = &flushAllTaskbyStreamingService{
|
||||
flushAllTask: ft,
|
||||
chMgr: node.chMgr,
|
||||
}
|
||||
}
|
||||
|
||||
if err := node.sched.dcQueue.Enqueue(enqueuedTask); err != nil {
|
||||
if err := node.sched.dcQueue.Enqueue(ft); err != nil {
|
||||
log.Warn(rpcFailedToEnqueue(method), zap.Error(err))
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel, request.GetDbName(), "").Inc()
|
||||
resp.Status = merr.Status(err)
|
||||
@ -5198,9 +5146,6 @@ func (node *Proxy) CreateCredential(ctx context.Context, req *milvuspb.CreateCre
|
||||
zap.Error(err))
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
if merr.Ok(result) {
|
||||
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
@ -5273,9 +5218,6 @@ func (node *Proxy) UpdateCredential(ctx context.Context, req *milvuspb.UpdateCre
|
||||
zap.Error(err))
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
if merr.Ok(result) {
|
||||
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
@ -5306,9 +5248,6 @@ func (node *Proxy) DeleteCredential(ctx context.Context, req *milvuspb.DeleteCre
|
||||
zap.Error(err))
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
if merr.Ok(result) {
|
||||
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
@ -5372,9 +5311,6 @@ func (node *Proxy) CreateRole(ctx context.Context, req *milvuspb.CreateRoleReque
|
||||
log.Warn("fail to create role", zap.Error(err))
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
if merr.Ok(result) {
|
||||
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@ -5407,9 +5343,6 @@ func (node *Proxy) DropRole(ctx context.Context, req *milvuspb.DropRoleRequest)
|
||||
zap.Error(err))
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
if merr.Ok(result) {
|
||||
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@ -5439,9 +5372,6 @@ func (node *Proxy) OperateUserRole(ctx context.Context, req *milvuspb.OperateUse
|
||||
log.Warn("fail to operate user role", zap.Error(err))
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
if merr.Ok(result) {
|
||||
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@ -5624,9 +5554,6 @@ func (node *Proxy) OperatePrivilegeV2(ctx context.Context, req *milvuspb.Operate
|
||||
}
|
||||
}
|
||||
}
|
||||
if merr.Ok(result) {
|
||||
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@ -5675,9 +5602,6 @@ func (node *Proxy) OperatePrivilege(ctx context.Context, req *milvuspb.OperatePr
|
||||
}
|
||||
}
|
||||
}
|
||||
if merr.Ok(result) {
|
||||
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@ -5914,11 +5838,6 @@ func (node *Proxy) RenameCollection(ctx context.Context, req *milvuspb.RenameCol
|
||||
log.Warn("failed to rename collection", zap.Error(err))
|
||||
return merr.Status(err), err
|
||||
}
|
||||
|
||||
if merr.Ok(resp) {
|
||||
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@ -6432,136 +6351,9 @@ func (node *Proxy) Connect(ctx context.Context, request *milvuspb.ConnectRequest
|
||||
}
|
||||
|
||||
func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.ReplicateMessageRequest) (*milvuspb.ReplicateMessageResponse, error) {
|
||||
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
|
||||
}
|
||||
|
||||
var err error
|
||||
if req.GetChannelName() == "" {
|
||||
log.Ctx(ctx).Warn("channel name is empty")
|
||||
return &milvuspb.ReplicateMessageResponse{
|
||||
Status: merr.Status(merr.WrapErrParameterInvalidMsg("invalid channel name for the replicate message request")),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// get the latest position of the replicate msg channel
|
||||
replicateMsgChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
|
||||
if req.GetChannelName() == replicateMsgChannel {
|
||||
msgID, err := msgstream.GetChannelLatestMsgID(ctx, node.factory, replicateMsgChannel)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("failed to get the latest message id of the replicate msg channel", zap.Error(err))
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
|
||||
}
|
||||
position := &msgpb.MsgPosition{
|
||||
ChannelName: replicateMsgChannel,
|
||||
MsgID: msgID,
|
||||
}
|
||||
positionBytes, err := proto.Marshal(position)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("failed to marshal position", zap.Error(err))
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
|
||||
}
|
||||
return &milvuspb.ReplicateMessageResponse{
|
||||
Status: merr.Status(nil),
|
||||
Position: base64.StdEncoding.EncodeToString(positionBytes),
|
||||
}, nil
|
||||
}
|
||||
|
||||
collectionReplicateEnable := paramtable.Get().CommonCfg.CollectionReplicateEnable.GetAsBool()
|
||||
ttMsgEnabled := paramtable.Get().CommonCfg.TTMsgEnabled.GetAsBool()
|
||||
|
||||
// replicate message can be use in two ways, otherwise return error
|
||||
// 1. collectionReplicateEnable is false and ttMsgEnabled is false, active/standby mode
|
||||
// 2. collectionReplicateEnable is true and ttMsgEnabled is true, data migration mode
|
||||
if (!collectionReplicateEnable && ttMsgEnabled) || (collectionReplicateEnable && !ttMsgEnabled) {
|
||||
return &milvuspb.ReplicateMessageResponse{
|
||||
Status: merr.Status(merr.ErrDenyReplicateMessage),
|
||||
}, nil
|
||||
}
|
||||
|
||||
msgPack := &msgstream.MsgPack{
|
||||
BeginTs: req.BeginTs,
|
||||
EndTs: req.EndTs,
|
||||
Msgs: make([]msgstream.TsMsg, 0),
|
||||
StartPositions: req.StartPositions,
|
||||
EndPositions: req.EndPositions,
|
||||
}
|
||||
checkCollectionReplicateProperty := func(dbName, collectionName string) bool {
|
||||
if !collectionReplicateEnable {
|
||||
return true
|
||||
}
|
||||
replicateID, err := GetReplicateID(ctx, dbName, collectionName)
|
||||
if err != nil {
|
||||
log.Warn("get replicate id failed", zap.String("collectionName", collectionName), zap.Error(err))
|
||||
return false
|
||||
}
|
||||
return replicateID != ""
|
||||
}
|
||||
|
||||
// getTsMsgFromConsumerMsg
|
||||
for i, msgBytes := range req.Msgs {
|
||||
header := commonpb.MsgHeader{}
|
||||
err = proto.Unmarshal(msgBytes, &header)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("failed to unmarshal msg header", zap.Int("index", i), zap.Error(err))
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
|
||||
}
|
||||
if header.GetBase() == nil {
|
||||
log.Ctx(ctx).Warn("msg header base is nil", zap.Int("index", i))
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.ErrInvalidMsgBytes)}, nil
|
||||
}
|
||||
tsMsg, err := node.replicateStreamManager.GetMsgDispatcher().Unmarshal(msgBytes, header.GetBase().GetMsgType())
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("failed to unmarshal msg", zap.Int("index", i), zap.Error(err))
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.ErrInvalidMsgBytes)}, nil
|
||||
}
|
||||
switch realMsg := tsMsg.(type) {
|
||||
case *msgstream.InsertMsg:
|
||||
if !checkCollectionReplicateProperty(realMsg.GetDbName(), realMsg.GetCollectionName()) {
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.WrapErrCollectionReplicateMode("replicate"))}, nil
|
||||
}
|
||||
assignedSegmentInfos, err := node.segAssigner.GetSegmentID(realMsg.GetCollectionID(), realMsg.GetPartitionID(),
|
||||
realMsg.GetShardName(), uint32(realMsg.NumRows), req.EndTs)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("failed to get segment id", zap.Error(err))
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
|
||||
}
|
||||
if len(assignedSegmentInfos) == 0 {
|
||||
log.Ctx(ctx).Warn("no segment id assigned")
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.ErrNoAssignSegmentID)}, nil
|
||||
}
|
||||
for assignSegmentID := range assignedSegmentInfos {
|
||||
realMsg.SegmentID = assignSegmentID
|
||||
break
|
||||
}
|
||||
case *msgstream.DeleteMsg:
|
||||
if !checkCollectionReplicateProperty(realMsg.GetDbName(), realMsg.GetCollectionName()) {
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.WrapErrCollectionReplicateMode("replicate"))}, nil
|
||||
}
|
||||
}
|
||||
msgPack.Msgs = append(msgPack.Msgs, tsMsg)
|
||||
}
|
||||
|
||||
msgStream, err := node.replicateStreamManager.GetReplicateMsgStream(ctx, req.ChannelName)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("failed to get msg stream from the replicate stream manager", zap.Error(err))
|
||||
return &milvuspb.ReplicateMessageResponse{
|
||||
Status: merr.Status(err),
|
||||
}, nil
|
||||
}
|
||||
messageIDsMap, err := msgStream.Broadcast(ctx, msgPack)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("failed to produce msg", zap.Error(err))
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
|
||||
}
|
||||
var position string
|
||||
if len(messageIDsMap[req.GetChannelName()]) == 0 {
|
||||
log.Ctx(ctx).Warn("no message id returned")
|
||||
} else {
|
||||
messageIDs := messageIDsMap[req.GetChannelName()]
|
||||
position = base64.StdEncoding.EncodeToString(messageIDs[len(messageIDs)-1].Serialize())
|
||||
}
|
||||
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(nil), Position: position}, nil
|
||||
return &milvuspb.ReplicateMessageResponse{
|
||||
Status: merr.Status(merr.WrapErrServiceUnavailable("not supported in streaming mode")),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (node *Proxy) ListClientInfos(ctx context.Context, req *proxypb.ListClientInfosRequest) (*proxypb.ListClientInfosResponse, error) {
|
||||
@ -6821,9 +6613,6 @@ func (node *Proxy) CreatePrivilegeGroup(ctx context.Context, req *milvuspb.Creat
|
||||
log.Warn("fail to create privilege group", zap.Error(err))
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
if merr.Ok(result) {
|
||||
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@ -6853,9 +6642,6 @@ func (node *Proxy) DropPrivilegeGroup(ctx context.Context, req *milvuspb.DropPri
|
||||
log.Warn("fail to drop privilege group", zap.Error(err))
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
if merr.Ok(result) {
|
||||
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@ -6919,9 +6705,6 @@ func (node *Proxy) OperatePrivilegeGroup(ctx context.Context, req *milvuspb.Oper
|
||||
log.Warn("fail to operate privilege group", zap.Error(err))
|
||||
return merr.Status(err), nil
|
||||
}
|
||||
if merr.Ok(result) {
|
||||
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
|
||||
@ -18,12 +18,10 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/mockey"
|
||||
"github.com/cockroachdb/errors"
|
||||
@ -31,38 +29,27 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
grpcmixcoordclient "github.com/milvus-io/milvus/internal/distributed/mixcoord/client"
|
||||
"github.com/milvus-io/milvus/internal/distributed/streaming"
|
||||
mhttp "github.com/milvus-io/milvus/internal/http"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/mocks/distributed/mock_streaming"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
mqcommon "github.com/milvus-io/milvus/pkg/v2/mq/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/proxypb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/ratelimitutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/resource"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
@ -254,7 +241,7 @@ func TestProxy_ResourceGroup(t *testing.T) {
|
||||
qc.EXPECT().ShowLoadCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
|
||||
|
||||
tsoAllocatorIns := newMockTsoAllocator()
|
||||
node.sched, err = newTaskScheduler(node.ctx, tsoAllocatorIns, node.factory)
|
||||
node.sched, err = newTaskScheduler(node.ctx, tsoAllocatorIns)
|
||||
assert.NoError(t, err)
|
||||
node.sched.Start()
|
||||
defer node.sched.Close()
|
||||
@ -336,7 +323,7 @@ func TestProxy_InvalidResourceGroupName(t *testing.T) {
|
||||
Status: merr.Success(),
|
||||
}, nil).Maybe()
|
||||
tsoAllocatorIns := newMockTsoAllocator()
|
||||
node.sched, err = newTaskScheduler(node.ctx, tsoAllocatorIns, node.factory)
|
||||
node.sched, err = newTaskScheduler(node.ctx, tsoAllocatorIns)
|
||||
assert.NoError(t, err)
|
||||
node.sched.Start()
|
||||
defer node.sched.Close()
|
||||
@ -393,11 +380,7 @@ func createTestProxy() *Proxy {
|
||||
tso: newMockTimestampAllocatorInterface(),
|
||||
}
|
||||
|
||||
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
|
||||
node.replicateMsgStream, _ = node.factory.NewMsgStream(node.ctx)
|
||||
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})
|
||||
|
||||
node.sched, _ = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
|
||||
node.sched, _ = newTaskScheduler(ctx, node.tsoAllocator)
|
||||
node.sched.Start()
|
||||
|
||||
return node
|
||||
@ -418,10 +401,12 @@ func TestProxy_FlushAll_NoDatabase(t *testing.T) {
|
||||
mockey.Mock(paramtable.Init).Return().Build()
|
||||
mockey.Mock((*paramtable.ComponentParam).Save).Return().Build()
|
||||
|
||||
// Mock grpc mix coord client FlushAll method
|
||||
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
|
||||
mockey.Mock((*grpcmixcoordclient.Client).FlushAll).To(func(ctx context.Context, req *datapb.FlushAllRequest, opts ...grpc.CallOption) (*datapb.FlushAllResponse, error) {
|
||||
return &datapb.FlushAllResponse{Status: successStatus}, nil
|
||||
mockey.Mock((*grpcmixcoordclient.Client).ListDatabases).To(func(ctx context.Context, req *milvuspb.ListDatabasesRequest, opts ...grpc.CallOption) (*milvuspb.ListDatabasesResponse, error) {
|
||||
return &milvuspb.ListDatabasesResponse{Status: successStatus}, nil
|
||||
}).Build()
|
||||
mockey.Mock((*grpcmixcoordclient.Client).ShowCollections).To(func(ctx context.Context, req *milvuspb.ShowCollectionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowCollectionsResponse, error) {
|
||||
return &milvuspb.ShowCollectionsResponse{Status: successStatus}, nil
|
||||
}).Build()
|
||||
|
||||
// Act: Execute test
|
||||
@ -454,10 +439,13 @@ func TestProxy_FlushAll_WithDefaultDatabase(t *testing.T) {
|
||||
mockey.Mock(paramtable.Init).Return().Build()
|
||||
mockey.Mock((*paramtable.ComponentParam).Save).Return().Build()
|
||||
|
||||
// Mock grpc mix coord client FlushAll method for default database
|
||||
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
|
||||
mockey.Mock((*grpcmixcoordclient.Client).FlushAll).To(func(ctx context.Context, req *datapb.FlushAllRequest, opts ...grpc.CallOption) (*datapb.FlushAllResponse, error) {
|
||||
return &datapb.FlushAllResponse{Status: successStatus}, nil
|
||||
mockey.Mock((*grpcmixcoordclient.Client).ListDatabases).To(func(ctx context.Context, req *milvuspb.ListDatabasesRequest, opts ...grpc.CallOption) (*milvuspb.ListDatabasesResponse, error) {
|
||||
return &milvuspb.ListDatabasesResponse{Status: successStatus, DbNames: []string{"default"}}, nil
|
||||
}).Build()
|
||||
// Mock grpc mix coord client FlushAll method for default database
|
||||
mockey.Mock((*grpcmixcoordclient.Client).ShowCollections).To(func(ctx context.Context, req *milvuspb.ShowCollectionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowCollectionsResponse, error) {
|
||||
return &milvuspb.ShowCollectionsResponse{Status: successStatus}, nil
|
||||
}).Build()
|
||||
|
||||
// Act: Execute test
|
||||
@ -490,9 +478,8 @@ func TestProxy_FlushAll_DatabaseNotExist(t *testing.T) {
|
||||
mockey.Mock(paramtable.Init).Return().Build()
|
||||
mockey.Mock((*paramtable.ComponentParam).Save).Return().Build()
|
||||
|
||||
// Mock grpc mix coord client FlushAll method for non-existent database
|
||||
mockey.Mock((*grpcmixcoordclient.Client).FlushAll).To(func(ctx context.Context, req *datapb.FlushAllRequest, opts ...grpc.CallOption) (*datapb.FlushAllResponse, error) {
|
||||
return &datapb.FlushAllResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_MetaFailed}}, nil
|
||||
mockey.Mock((*grpcmixcoordclient.Client).ShowCollections).To(func(ctx context.Context, req *milvuspb.ShowCollectionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowCollectionsResponse, error) {
|
||||
return &milvuspb.ShowCollectionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_MetaFailed}}, nil
|
||||
}).Build()
|
||||
|
||||
// Act: Execute test
|
||||
@ -889,18 +876,13 @@ func TestProxyCreateDatabase(t *testing.T) {
|
||||
}
|
||||
node.simpleLimiter = NewSimpleLimiter(0, 0)
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
|
||||
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator)
|
||||
node.sched.ddQueue.setMaxTaskNum(10)
|
||||
assert.NoError(t, err)
|
||||
err = node.sched.Start()
|
||||
assert.NoError(t, err)
|
||||
defer node.sched.Close()
|
||||
|
||||
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
|
||||
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
|
||||
assert.NoError(t, err)
|
||||
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})
|
||||
|
||||
t.Run("create database fail", func(t *testing.T) {
|
||||
mixc := mocks.NewMockMixCoordClient(t)
|
||||
mixc.On("CreateDatabase", mock.Anything, mock.Anything).
|
||||
@ -949,18 +931,13 @@ func TestProxyDropDatabase(t *testing.T) {
|
||||
}
|
||||
node.simpleLimiter = NewSimpleLimiter(0, 0)
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
|
||||
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator)
|
||||
node.sched.ddQueue.setMaxTaskNum(10)
|
||||
assert.NoError(t, err)
|
||||
err = node.sched.Start()
|
||||
assert.NoError(t, err)
|
||||
defer node.sched.Close()
|
||||
|
||||
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
|
||||
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
|
||||
assert.NoError(t, err)
|
||||
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})
|
||||
|
||||
t.Run("drop database fail", func(t *testing.T) {
|
||||
mixc := mocks.NewMockMixCoordClient(t)
|
||||
mixc.On("DropDatabase", mock.Anything, mock.Anything).
|
||||
@ -1007,7 +984,7 @@ func TestProxyListDatabase(t *testing.T) {
|
||||
}
|
||||
node.simpleLimiter = NewSimpleLimiter(0, 0)
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
|
||||
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator)
|
||||
node.sched.ddQueue.setMaxTaskNum(10)
|
||||
assert.NoError(t, err)
|
||||
err = node.sched.Start()
|
||||
@ -1063,7 +1040,7 @@ func TestProxyAlterDatabase(t *testing.T) {
|
||||
}
|
||||
node.simpleLimiter = NewSimpleLimiter(0, 0)
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
|
||||
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator)
|
||||
node.sched.ddQueue.setMaxTaskNum(10)
|
||||
assert.NoError(t, err)
|
||||
err = node.sched.Start()
|
||||
@ -1116,7 +1093,7 @@ func TestProxyDescribeDatabase(t *testing.T) {
|
||||
}
|
||||
node.simpleLimiter = NewSimpleLimiter(0, 0)
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
|
||||
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator)
|
||||
node.sched.ddQueue.setMaxTaskNum(10)
|
||||
assert.NoError(t, err)
|
||||
err = node.sched.Start()
|
||||
@ -1347,7 +1324,7 @@ func TestProxy_Delete(t *testing.T) {
|
||||
idAllocator, err := allocator.NewIDAllocator(ctx, rc, 0)
|
||||
assert.NoError(t, err)
|
||||
|
||||
queue, err := newTaskScheduler(ctx, tsoAllocator, nil)
|
||||
queue, err := newTaskScheduler(ctx, tsoAllocator)
|
||||
assert.NoError(t, err)
|
||||
|
||||
node := &Proxy{chMgr: chMgr, rowIDAllocator: idAllocator, sched: queue}
|
||||
@ -1358,287 +1335,7 @@ func TestProxy_Delete(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxy_ReplicateMessage(t *testing.T) {
|
||||
paramtable.Init()
|
||||
defer paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true")
|
||||
t.Run("proxy unhealthy", func(t *testing.T) {
|
||||
node := &Proxy{}
|
||||
node.UpdateStateCode(commonpb.StateCode_Abnormal)
|
||||
|
||||
resp, err := node.ReplicateMessage(context.TODO(), nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, 0, resp.GetStatus().GetCode())
|
||||
})
|
||||
|
||||
t.Run("not backup instance", func(t *testing.T) {
|
||||
node := &Proxy{}
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
|
||||
resp, err := node.ReplicateMessage(context.TODO(), nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, 0, resp.GetStatus().GetCode())
|
||||
})
|
||||
|
||||
t.Run("empty channel name", func(t *testing.T) {
|
||||
node := &Proxy{}
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false")
|
||||
|
||||
resp, err := node.ReplicateMessage(context.TODO(), nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, 0, resp.GetStatus().GetCode())
|
||||
})
|
||||
|
||||
t.Run("fail to get msg stream", func(t *testing.T) {
|
||||
factory := newMockMsgStreamFactory()
|
||||
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
return nil, errors.New("mock error: get msg stream")
|
||||
}
|
||||
resourceManager := resource.NewManager(time.Second, 2*time.Second, nil)
|
||||
manager := NewReplicateStreamManager(context.Background(), factory, resourceManager)
|
||||
|
||||
node := &Proxy{
|
||||
replicateStreamManager: manager,
|
||||
}
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false")
|
||||
|
||||
resp, err := node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{ChannelName: "unit_test_replicate_message"})
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, 0, resp.GetStatus().GetCode())
|
||||
})
|
||||
|
||||
t.Run("get latest position", func(t *testing.T) {
|
||||
base64DecodeMsgPosition := func(position string) (*msgstream.MsgPosition, error) {
|
||||
decodeBytes, err := base64.StdEncoding.DecodeString(position)
|
||||
if err != nil {
|
||||
log.Warn("fail to decode the position", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
msgPosition := &msgstream.MsgPosition{}
|
||||
err = proto.Unmarshal(decodeBytes, msgPosition)
|
||||
if err != nil {
|
||||
log.Warn("fail to unmarshal the position", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
return msgPosition, nil
|
||||
}
|
||||
|
||||
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false")
|
||||
defer paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true")
|
||||
|
||||
factory := dependency.NewMockFactory(t)
|
||||
stream := msgstream.NewMockMsgStream(t)
|
||||
mockMsgID := mqcommon.NewMockMessageID(t)
|
||||
|
||||
factory.EXPECT().NewMsgStream(mock.Anything).Return(stream, nil).Once()
|
||||
mockMsgID.EXPECT().Serialize().Return([]byte("mock")).Once()
|
||||
stream.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once()
|
||||
stream.EXPECT().GetLatestMsgID(mock.Anything).Return(mockMsgID, nil).Once()
|
||||
stream.EXPECT().Close().Return()
|
||||
node := &Proxy{
|
||||
factory: factory,
|
||||
}
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
resp, err := node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{
|
||||
ChannelName: Params.CommonCfg.ReplicateMsgChannel.GetValue(),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 0, resp.GetStatus().GetCode())
|
||||
{
|
||||
p, err := base64DecodeMsgPosition(resp.GetPosition())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("mock"), p.MsgID)
|
||||
}
|
||||
|
||||
factory.EXPECT().NewMsgStream(mock.Anything).Return(nil, errors.New("mock")).Once()
|
||||
resp, err = node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{
|
||||
ChannelName: Params.CommonCfg.ReplicateMsgChannel.GetValue(),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqualValues(t, 0, resp.GetStatus().GetCode())
|
||||
})
|
||||
|
||||
t.Run("invalid msg pack", func(t *testing.T) {
|
||||
node := &Proxy{
|
||||
replicateStreamManager: NewReplicateStreamManager(context.Background(), nil, nil),
|
||||
}
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false")
|
||||
{
|
||||
resp, err := node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{
|
||||
ChannelName: "unit_test_replicate_message",
|
||||
Msgs: [][]byte{{1, 2, 3}},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, 0, resp.GetStatus().GetCode())
|
||||
}
|
||||
|
||||
{
|
||||
timeTickMsg := &msgstream.TimeTickMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
BeginTimestamp: 1,
|
||||
EndTimestamp: 10,
|
||||
HashValues: []uint32{0},
|
||||
},
|
||||
TimeTickMsg: &msgpb.TimeTickMsg{},
|
||||
}
|
||||
msgBytes, _ := timeTickMsg.Marshal(timeTickMsg)
|
||||
resp, err := node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{
|
||||
ChannelName: "unit_test_replicate_message",
|
||||
Msgs: [][]byte{msgBytes.([]byte)},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
log.Info("resp", zap.Any("resp", resp))
|
||||
assert.NotEqual(t, 0, resp.GetStatus().GetCode())
|
||||
}
|
||||
|
||||
{
|
||||
timeTickMsg := &msgstream.TimeTickMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
BeginTimestamp: 1,
|
||||
EndTimestamp: 10,
|
||||
HashValues: []uint32{0},
|
||||
},
|
||||
TimeTickMsg: &msgpb.TimeTickMsg{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType(-1)),
|
||||
commonpbutil.WithTimeStamp(10),
|
||||
commonpbutil.WithSourceID(-1),
|
||||
),
|
||||
},
|
||||
}
|
||||
msgBytes, _ := timeTickMsg.Marshal(timeTickMsg)
|
||||
resp, err := node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{
|
||||
ChannelName: "unit_test_replicate_message",
|
||||
Msgs: [][]byte{msgBytes.([]byte)},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
log.Info("resp", zap.Any("resp", resp))
|
||||
assert.NotEqual(t, 0, resp.GetStatus().GetCode())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
paramtable.Init()
|
||||
factory := newMockMsgStreamFactory()
|
||||
msgStreamObj := msgstream.NewMockMsgStream(t)
|
||||
msgStreamObj.EXPECT().SetRepackFunc(mock.Anything).Return()
|
||||
msgStreamObj.EXPECT().AsProducer(mock.Anything, mock.Anything).Return()
|
||||
msgStreamObj.EXPECT().ForceEnableProduce(mock.Anything).Return()
|
||||
msgStreamObj.EXPECT().Close().Return()
|
||||
mockMsgID1 := mqcommon.NewMockMessageID(t)
|
||||
mockMsgID2 := mqcommon.NewMockMessageID(t)
|
||||
mockMsgID2.EXPECT().Serialize().Return([]byte("mock message id 2"))
|
||||
broadcastMock := msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{
|
||||
"unit_test_replicate_message": {mockMsgID1, mockMsgID2},
|
||||
}, nil)
|
||||
|
||||
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
return msgStreamObj, nil
|
||||
}
|
||||
resourceManager := resource.NewManager(time.Second, 2*time.Second, nil)
|
||||
manager := NewReplicateStreamManager(context.Background(), factory, resourceManager)
|
||||
|
||||
ctx := context.Background()
|
||||
dataCoord := &mockDataCoord{}
|
||||
dataCoord.expireTime = Timestamp(1000)
|
||||
segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick1)
|
||||
assert.NoError(t, err)
|
||||
segAllocator.Start()
|
||||
|
||||
node := &Proxy{
|
||||
replicateStreamManager: manager,
|
||||
segAssigner: segAllocator,
|
||||
}
|
||||
node.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false")
|
||||
|
||||
insertMsg := &msgstream.InsertMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
BeginTimestamp: 4,
|
||||
EndTimestamp: 10,
|
||||
HashValues: []uint32{0},
|
||||
MsgPosition: &msgstream.MsgPosition{
|
||||
ChannelName: "unit_test_replicate_message",
|
||||
MsgID: []byte("mock message id 2"),
|
||||
},
|
||||
},
|
||||
InsertRequest: &msgpb.InsertRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Insert,
|
||||
MsgID: 10001,
|
||||
Timestamp: 10,
|
||||
SourceID: -1,
|
||||
},
|
||||
ShardName: "unit_test_replicate_message_v1",
|
||||
DbName: "default",
|
||||
CollectionName: "foo_collection",
|
||||
PartitionName: "_default",
|
||||
DbID: 1,
|
||||
CollectionID: 11,
|
||||
PartitionID: 22,
|
||||
SegmentID: 33,
|
||||
Timestamps: []uint64{10},
|
||||
RowIDs: []int64{66},
|
||||
NumRows: 1,
|
||||
},
|
||||
}
|
||||
msgBytes, _ := insertMsg.Marshal(insertMsg)
|
||||
|
||||
replicateRequest := &milvuspb.ReplicateMessageRequest{
|
||||
ChannelName: "unit_test_replicate_message",
|
||||
BeginTs: 1,
|
||||
EndTs: 10,
|
||||
Msgs: [][]byte{msgBytes.([]byte)},
|
||||
StartPositions: []*msgpb.MsgPosition{
|
||||
{ChannelName: "unit_test_replicate_message", MsgID: []byte("mock message id 1")},
|
||||
},
|
||||
EndPositions: []*msgpb.MsgPosition{
|
||||
{ChannelName: "unit_test_replicate_message", MsgID: []byte("mock message id 2")},
|
||||
},
|
||||
}
|
||||
resp, err := node.ReplicateMessage(context.TODO(), replicateRequest)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 0, resp.GetStatus().GetCode())
|
||||
assert.Equal(t, base64.StdEncoding.EncodeToString([]byte("mock message id 2")), resp.GetPosition())
|
||||
|
||||
res := resourceManager.Delete(ReplicateMsgStreamTyp, replicateRequest.GetChannelName())
|
||||
assert.NotNil(t, res)
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
{
|
||||
broadcastMock.Unset()
|
||||
broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(nil, errors.New("mock error: broadcast"))
|
||||
resp, err := node.ReplicateMessage(context.TODO(), replicateRequest)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqualValues(t, 0, resp.GetStatus().GetCode())
|
||||
resourceManager.Delete(ReplicateMsgStreamTyp, replicateRequest.GetChannelName())
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
{
|
||||
broadcastMock.Unset()
|
||||
broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{
|
||||
"unit_test_replicate_message": {},
|
||||
}, nil)
|
||||
resp, err := node.ReplicateMessage(context.TODO(), replicateRequest)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 0, resp.GetStatus().GetCode())
|
||||
assert.Empty(t, resp.GetPosition())
|
||||
resourceManager.Delete(ReplicateMsgStreamTyp, replicateRequest.GetChannelName())
|
||||
time.Sleep(2 * time.Second)
|
||||
broadcastMock.Unset()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestProxy_ImportV2(t *testing.T) {
|
||||
wal := mock_streaming.NewMockWALAccesser(t)
|
||||
b := mock_streaming.NewMockBroadcast(t)
|
||||
wal.EXPECT().Broadcast().Return(b).Maybe()
|
||||
b.EXPECT().Append(mock.Anything, mock.Anything).Return(&types.BroadcastAppendResult{}, nil).Maybe()
|
||||
streaming.SetWALForTest(wal)
|
||||
defer streaming.RecoverWALForTest()
|
||||
ctx := context.Background()
|
||||
mockErr := errors.New("mock error")
|
||||
|
||||
@ -1661,7 +1358,7 @@ func TestProxy_ImportV2(t *testing.T) {
|
||||
node.tsoAllocator = ×tampAllocator{
|
||||
tso: newMockTimestampAllocatorInterface(),
|
||||
}
|
||||
scheduler, err := newTaskScheduler(ctx, node.tsoAllocator, factory)
|
||||
scheduler, err := newTaskScheduler(ctx, node.tsoAllocator)
|
||||
assert.NoError(t, err)
|
||||
node.sched = scheduler
|
||||
err = node.sched.Start()
|
||||
@ -1916,333 +1613,6 @@ func TestRegisterRestRouter(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicateMessageForCollectionMode(t *testing.T) {
|
||||
paramtable.Init()
|
||||
ctx := context.Background()
|
||||
insertMsg := &msgstream.InsertMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
BeginTimestamp: 10,
|
||||
EndTimestamp: 10,
|
||||
HashValues: []uint32{0},
|
||||
MsgPosition: &msgstream.MsgPosition{
|
||||
ChannelName: "foo",
|
||||
MsgID: []byte("mock message id 2"),
|
||||
},
|
||||
},
|
||||
InsertRequest: &msgpb.InsertRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Insert,
|
||||
MsgID: 10001,
|
||||
Timestamp: 10,
|
||||
SourceID: -1,
|
||||
},
|
||||
ShardName: "foo_v1",
|
||||
DbName: "default",
|
||||
CollectionName: "foo_collection",
|
||||
PartitionName: "_default",
|
||||
DbID: 1,
|
||||
CollectionID: 11,
|
||||
PartitionID: 22,
|
||||
SegmentID: 33,
|
||||
Timestamps: []uint64{10},
|
||||
RowIDs: []int64{66},
|
||||
NumRows: 1,
|
||||
},
|
||||
}
|
||||
insertMsgBytes, _ := insertMsg.Marshal(insertMsg)
|
||||
deleteMsg := &msgstream.DeleteMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
BeginTimestamp: 20,
|
||||
EndTimestamp: 20,
|
||||
HashValues: []uint32{0},
|
||||
MsgPosition: &msgstream.MsgPosition{
|
||||
ChannelName: "foo",
|
||||
MsgID: []byte("mock message id 2"),
|
||||
},
|
||||
},
|
||||
DeleteRequest: &msgpb.DeleteRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Delete,
|
||||
MsgID: 10002,
|
||||
Timestamp: 20,
|
||||
SourceID: -1,
|
||||
},
|
||||
ShardName: "foo_v1",
|
||||
DbName: "default",
|
||||
CollectionName: "foo_collection",
|
||||
PartitionName: "_default",
|
||||
DbID: 1,
|
||||
CollectionID: 11,
|
||||
PartitionID: 22,
|
||||
},
|
||||
}
|
||||
deleteMsgBytes, _ := deleteMsg.Marshal(deleteMsg)
|
||||
|
||||
cache := globalMetaCache
|
||||
defer func() { globalMetaCache = cache }()
|
||||
|
||||
t.Run("replicate message in the replicate collection mode", func(t *testing.T) {
|
||||
defer func() {
|
||||
paramtable.Get().Reset(paramtable.Get().CommonCfg.TTMsgEnabled.Key)
|
||||
paramtable.Get().Reset(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key)
|
||||
}()
|
||||
|
||||
{
|
||||
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true")
|
||||
paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "false")
|
||||
p := &Proxy{}
|
||||
p.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
|
||||
ChannelName: "foo",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Error(t, merr.Error(r.Status))
|
||||
}
|
||||
|
||||
{
|
||||
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false")
|
||||
paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "true")
|
||||
p := &Proxy{}
|
||||
p.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
|
||||
ChannelName: "foo",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Error(t, merr.Error(r.Status))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("replicate message for the replicate collection mode", func(t *testing.T) {
|
||||
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true")
|
||||
paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "true")
|
||||
defer func() {
|
||||
paramtable.Get().Reset(paramtable.Get().CommonCfg.TTMsgEnabled.Key)
|
||||
paramtable.Get().Reset(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key)
|
||||
}()
|
||||
|
||||
mockCache := NewMockCache(t)
|
||||
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil).Twice()
|
||||
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{}, nil).Twice()
|
||||
globalMetaCache = mockCache
|
||||
|
||||
{
|
||||
p := &Proxy{
|
||||
replicateStreamManager: NewReplicateStreamManager(context.Background(), nil, nil),
|
||||
}
|
||||
p.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
|
||||
ChannelName: "foo",
|
||||
Msgs: [][]byte{insertMsgBytes.([]byte)},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, r.GetStatus().GetCode(), merr.Code(merr.ErrCollectionReplicateMode))
|
||||
}
|
||||
|
||||
{
|
||||
p := &Proxy{
|
||||
replicateStreamManager: NewReplicateStreamManager(context.Background(), nil, nil),
|
||||
}
|
||||
p.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
|
||||
ChannelName: "foo",
|
||||
Msgs: [][]byte{deleteMsgBytes.([]byte)},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, r.GetStatus().GetCode(), merr.Code(merr.ErrCollectionReplicateMode))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAlterCollectionReplicateProperty(t *testing.T) {
|
||||
paramtable.Init()
|
||||
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true")
|
||||
paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "true")
|
||||
defer func() {
|
||||
paramtable.Get().Reset(paramtable.Get().CommonCfg.TTMsgEnabled.Key)
|
||||
paramtable.Get().Reset(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key)
|
||||
}()
|
||||
cache := globalMetaCache
|
||||
defer func() { globalMetaCache = cache }()
|
||||
mockCache := NewMockCache(t)
|
||||
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{
|
||||
replicateID: "local-milvus",
|
||||
}, nil).Maybe()
|
||||
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(1, nil).Maybe()
|
||||
mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{}, nil)
|
||||
globalMetaCache = mockCache
|
||||
|
||||
factory := newMockMsgStreamFactory()
|
||||
msgStreamObj := msgstream.NewMockMsgStream(t)
|
||||
msgStreamObj.EXPECT().SetRepackFunc(mock.Anything).Return().Maybe()
|
||||
msgStreamObj.EXPECT().AsProducer(mock.Anything, mock.Anything).Return().Maybe()
|
||||
msgStreamObj.EXPECT().ForceEnableProduce(mock.Anything).Return().Maybe()
|
||||
msgStreamObj.EXPECT().Close().Return().Maybe()
|
||||
mockMsgID1 := mqcommon.NewMockMessageID(t)
|
||||
mockMsgID2 := mqcommon.NewMockMessageID(t)
|
||||
mockMsgID2.EXPECT().Serialize().Return([]byte("mock message id 2")).Maybe()
|
||||
msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{
|
||||
"alter_property": {mockMsgID1, mockMsgID2},
|
||||
}, nil).Maybe()
|
||||
|
||||
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
return msgStreamObj, nil
|
||||
}
|
||||
resourceManager := resource.NewManager(time.Second, 2*time.Second, nil)
|
||||
manager := NewReplicateStreamManager(context.Background(), factory, resourceManager)
|
||||
|
||||
ctx := context.Background()
|
||||
var startTt uint64 = 10
|
||||
startTime := time.Now()
|
||||
dataCoord := &mockDataCoord{}
|
||||
dataCoord.expireTime = Timestamp(1000)
|
||||
segAllocator, err := newSegIDAssigner(ctx, dataCoord, func() Timestamp {
|
||||
return Timestamp(time.Since(startTime).Seconds()) + startTt
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
segAllocator.Start()
|
||||
|
||||
mockMixcoord := mocks.NewMockMixCoordClient(t)
|
||||
mockMixcoord.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request *rootcoordpb.AllocTimestampRequest, option ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) {
|
||||
return &rootcoordpb.AllocTimestampResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
Timestamp: Timestamp(time.Since(startTime).Seconds()) + startTt,
|
||||
}, nil
|
||||
})
|
||||
mockMixcoord.EXPECT().AlterCollection(mock.Anything, mock.Anything).Return(&commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
}, nil)
|
||||
|
||||
p := &Proxy{
|
||||
ctx: ctx,
|
||||
replicateStreamManager: manager,
|
||||
segAssigner: segAllocator,
|
||||
mixCoord: mockMixcoord,
|
||||
}
|
||||
tsoAllocatorIns := newMockTsoAllocator()
|
||||
p.sched, err = newTaskScheduler(p.ctx, tsoAllocatorIns, p.factory)
|
||||
assert.NoError(t, err)
|
||||
p.sched.Start()
|
||||
defer p.sched.Close()
|
||||
p.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
|
||||
getInsertMsgBytes := func(channel string, ts uint64) []byte {
|
||||
insertMsg := &msgstream.InsertMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
BeginTimestamp: ts,
|
||||
EndTimestamp: ts,
|
||||
HashValues: []uint32{0},
|
||||
MsgPosition: &msgstream.MsgPosition{
|
||||
ChannelName: channel,
|
||||
MsgID: []byte("mock message id 2"),
|
||||
},
|
||||
},
|
||||
InsertRequest: &msgpb.InsertRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Insert,
|
||||
MsgID: 10001,
|
||||
Timestamp: ts,
|
||||
SourceID: -1,
|
||||
},
|
||||
ShardName: channel + "_v1",
|
||||
DbName: "default",
|
||||
CollectionName: "foo_collection",
|
||||
PartitionName: "_default",
|
||||
DbID: 1,
|
||||
CollectionID: 11,
|
||||
PartitionID: 22,
|
||||
SegmentID: 33,
|
||||
Timestamps: []uint64{ts},
|
||||
RowIDs: []int64{66},
|
||||
NumRows: 1,
|
||||
},
|
||||
}
|
||||
insertMsgBytes, _ := insertMsg.Marshal(insertMsg)
|
||||
return insertMsgBytes.([]byte)
|
||||
}
|
||||
getDeleteMsgBytes := func(channel string, ts uint64) []byte {
|
||||
deleteMsg := &msgstream.DeleteMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
BeginTimestamp: ts,
|
||||
EndTimestamp: ts,
|
||||
HashValues: []uint32{0},
|
||||
MsgPosition: &msgstream.MsgPosition{
|
||||
ChannelName: "foo",
|
||||
MsgID: []byte("mock message id 2"),
|
||||
},
|
||||
},
|
||||
DeleteRequest: &msgpb.DeleteRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_Delete,
|
||||
MsgID: 10002,
|
||||
Timestamp: ts,
|
||||
SourceID: -1,
|
||||
},
|
||||
ShardName: channel + "_v1",
|
||||
DbName: "default",
|
||||
CollectionName: "foo_collection",
|
||||
PartitionName: "_default",
|
||||
DbID: 1,
|
||||
CollectionID: 11,
|
||||
PartitionID: 22,
|
||||
},
|
||||
}
|
||||
deleteMsgBytes, _ := deleteMsg.Marshal(deleteMsg)
|
||||
return deleteMsgBytes.([]byte)
|
||||
}
|
||||
|
||||
go func() {
|
||||
// replicate message
|
||||
var replicateResp *milvuspb.ReplicateMessageResponse
|
||||
var err error
|
||||
replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
|
||||
ChannelName: "alter_property_1",
|
||||
Msgs: [][]byte{getInsertMsgBytes("alter_property_1", startTt+5)},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason)
|
||||
|
||||
replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
|
||||
ChannelName: "alter_property_2",
|
||||
Msgs: [][]byte{getDeleteMsgBytes("alter_property_2", startTt+5)},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason)
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
|
||||
ChannelName: "alter_property_1",
|
||||
Msgs: [][]byte{getInsertMsgBytes("alter_property_1", startTt+10)},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason)
|
||||
|
||||
replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
|
||||
ChannelName: "alter_property_2",
|
||||
Msgs: [][]byte{getInsertMsgBytes("alter_property_2", startTt+10)},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason)
|
||||
}()
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// alter collection property
|
||||
statusResp, err := p.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
|
||||
DbName: "default",
|
||||
CollectionName: "foo_collection",
|
||||
Properties: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "replicate.endTS",
|
||||
Value: "1",
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, merr.Ok(statusResp))
|
||||
}
|
||||
|
||||
func TestRunAnalyzer(t *testing.T) {
|
||||
paramtable.Init()
|
||||
ctx := context.Background()
|
||||
@ -2254,7 +1624,7 @@ func TestRunAnalyzer(t *testing.T) {
|
||||
p := &Proxy{}
|
||||
|
||||
tsoAllocatorIns := newMockTsoAllocator()
|
||||
sched, err := newTaskScheduler(ctx, tsoAllocatorIns, p.factory)
|
||||
sched, err := newTaskScheduler(ctx, tsoAllocatorIns)
|
||||
require.NoError(t, err)
|
||||
sched.Start()
|
||||
defer sched.Close()
|
||||
|
||||
@ -2,12 +2,7 @@
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
msgstream "github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
|
||||
// MockChannelsMgr is an autogenerated mock type for the channelsMgr type
|
||||
type MockChannelsMgr struct {
|
||||
@ -80,65 +75,6 @@ func (_c *MockChannelsMgr_getChannels_Call) RunAndReturn(run func(int64) ([]stri
|
||||
return _c
|
||||
}
|
||||
|
||||
// getOrCreateDmlStream provides a mock function with given fields: ctx, collectionID
|
||||
func (_m *MockChannelsMgr) getOrCreateDmlStream(ctx context.Context, collectionID int64) (msgstream.MsgStream, error) {
|
||||
ret := _m.Called(ctx, collectionID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for getOrCreateDmlStream")
|
||||
}
|
||||
|
||||
var r0 msgstream.MsgStream
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, int64) (msgstream.MsgStream, error)); ok {
|
||||
return rf(ctx, collectionID)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, int64) msgstream.MsgStream); ok {
|
||||
r0 = rf(ctx, collectionID)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(msgstream.MsgStream)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok {
|
||||
r1 = rf(ctx, collectionID)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockChannelsMgr_getOrCreateDmlStream_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'getOrCreateDmlStream'
|
||||
type MockChannelsMgr_getOrCreateDmlStream_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// getOrCreateDmlStream is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - collectionID int64
|
||||
func (_e *MockChannelsMgr_Expecter) getOrCreateDmlStream(ctx interface{}, collectionID interface{}) *MockChannelsMgr_getOrCreateDmlStream_Call {
|
||||
return &MockChannelsMgr_getOrCreateDmlStream_Call{Call: _e.mock.On("getOrCreateDmlStream", ctx, collectionID)}
|
||||
}
|
||||
|
||||
func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) Run(run func(ctx context.Context, collectionID int64)) *MockChannelsMgr_getOrCreateDmlStream_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(int64))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) Return(_a0 msgstream.MsgStream, _a1 error) *MockChannelsMgr_getOrCreateDmlStream_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) RunAndReturn(run func(context.Context, int64) (msgstream.MsgStream, error)) *MockChannelsMgr_getOrCreateDmlStream_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// getVChannels provides a mock function with given fields: collectionID
|
||||
func (_m *MockChannelsMgr) getVChannels(collectionID int64) ([]string, error) {
|
||||
ret := _m.Called(collectionID)
|
||||
@ -197,38 +133,6 @@ func (_c *MockChannelsMgr_getVChannels_Call) RunAndReturn(run func(int64) ([]str
|
||||
return _c
|
||||
}
|
||||
|
||||
// removeAllDMLStream provides a mock function with no fields
|
||||
func (_m *MockChannelsMgr) removeAllDMLStream() {
|
||||
_m.Called()
|
||||
}
|
||||
|
||||
// MockChannelsMgr_removeAllDMLStream_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'removeAllDMLStream'
|
||||
type MockChannelsMgr_removeAllDMLStream_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// removeAllDMLStream is a helper method to define mock.On call
|
||||
func (_e *MockChannelsMgr_Expecter) removeAllDMLStream() *MockChannelsMgr_removeAllDMLStream_Call {
|
||||
return &MockChannelsMgr_removeAllDMLStream_Call{Call: _e.mock.On("removeAllDMLStream")}
|
||||
}
|
||||
|
||||
func (_c *MockChannelsMgr_removeAllDMLStream_Call) Run(run func()) *MockChannelsMgr_removeAllDMLStream_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockChannelsMgr_removeAllDMLStream_Call) Return() *MockChannelsMgr_removeAllDMLStream_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockChannelsMgr_removeAllDMLStream_Call) RunAndReturn(run func()) *MockChannelsMgr_removeAllDMLStream_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// removeDMLStream provides a mock function with given fields: collectionID
|
||||
func (_m *MockChannelsMgr) removeDMLStream(collectionID int64) {
|
||||
_m.Called(collectionID)
|
||||
|
||||
@ -18,23 +18,12 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/retry"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
@ -103,197 +92,3 @@ func genInsertMsgsByPartition(ctx context.Context,
|
||||
|
||||
return repackedMsgs, nil
|
||||
}
|
||||
|
||||
func repackInsertDataByPartition(ctx context.Context,
|
||||
partitionName string,
|
||||
rowOffsets []int,
|
||||
channelName string,
|
||||
insertMsg *msgstream.InsertMsg,
|
||||
segIDAssigner *segIDAssigner,
|
||||
) ([]msgstream.TsMsg, error) {
|
||||
res := make([]msgstream.TsMsg, 0)
|
||||
|
||||
maxTs := Timestamp(0)
|
||||
for _, offset := range rowOffsets {
|
||||
ts := insertMsg.Timestamps[offset]
|
||||
if maxTs < ts {
|
||||
maxTs = ts
|
||||
}
|
||||
}
|
||||
|
||||
partitionID, err := globalMetaCache.GetPartitionID(ctx, insertMsg.GetDbName(), insertMsg.CollectionName, partitionName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
beforeAssign := time.Now()
|
||||
assignedSegmentInfos, err := segIDAssigner.GetSegmentID(insertMsg.CollectionID, partitionID, channelName, uint32(len(rowOffsets)), maxTs)
|
||||
metrics.ProxyAssignSegmentIDLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(time.Since(beforeAssign).Milliseconds()))
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error("allocate segmentID for insert data failed",
|
||||
zap.String("collectionName", insertMsg.CollectionName),
|
||||
zap.String("channelName", channelName),
|
||||
zap.Int("allocate count", len(rowOffsets)),
|
||||
zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startPos := 0
|
||||
for segmentID, count := range assignedSegmentInfos {
|
||||
subRowOffsets := rowOffsets[startPos : startPos+int(count)]
|
||||
msgs, err := genInsertMsgsByPartition(ctx, segmentID, partitionID, partitionName, subRowOffsets, channelName, insertMsg)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("repack insert data to insert msgs failed",
|
||||
zap.String("collectionName", insertMsg.CollectionName),
|
||||
zap.Int64("partitionID", partitionID),
|
||||
zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
res = append(res, msgs...)
|
||||
startPos += int(count)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func setMsgID(ctx context.Context,
|
||||
msgs []msgstream.TsMsg,
|
||||
idAllocator *allocator.IDAllocator,
|
||||
) error {
|
||||
var idBegin int64
|
||||
var err error
|
||||
|
||||
err = retry.Do(ctx, func() error {
|
||||
idBegin, _, err = idAllocator.Alloc(uint32(len(msgs)))
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error("failed to allocate msg id", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
for i, msg := range msgs {
|
||||
msg.SetID(idBegin + UniqueID(i))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func repackInsertData(ctx context.Context,
|
||||
channelNames []string,
|
||||
insertMsg *msgstream.InsertMsg,
|
||||
result *milvuspb.MutationResult,
|
||||
idAllocator *allocator.IDAllocator,
|
||||
segIDAssigner *segIDAssigner,
|
||||
) (*msgstream.MsgPack, error) {
|
||||
msgPack := &msgstream.MsgPack{
|
||||
BeginTs: insertMsg.BeginTs(),
|
||||
EndTs: insertMsg.EndTs(),
|
||||
}
|
||||
|
||||
channel2RowOffsets := assignChannelsByPK(result.IDs, channelNames, insertMsg)
|
||||
for channel, rowOffsets := range channel2RowOffsets {
|
||||
partitionName := insertMsg.PartitionName
|
||||
msgs, err := repackInsertDataByPartition(ctx, partitionName, rowOffsets, channel, insertMsg, segIDAssigner)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("repack insert data to msg pack failed",
|
||||
zap.String("collectionName", insertMsg.CollectionName),
|
||||
zap.String("partition name", partitionName),
|
||||
zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msgPack.Msgs = append(msgPack.Msgs, msgs...)
|
||||
}
|
||||
|
||||
err := setMsgID(ctx, msgPack.Msgs, idAllocator)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error("failed to set msgID when repack insert data",
|
||||
zap.String("collectionName", insertMsg.CollectionName),
|
||||
zap.String("partition name", insertMsg.PartitionName),
|
||||
zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return msgPack, nil
|
||||
}
|
||||
|
||||
func repackInsertDataWithPartitionKey(ctx context.Context,
|
||||
channelNames []string,
|
||||
partitionKeys *schemapb.FieldData,
|
||||
insertMsg *msgstream.InsertMsg,
|
||||
result *milvuspb.MutationResult,
|
||||
idAllocator *allocator.IDAllocator,
|
||||
segIDAssigner *segIDAssigner,
|
||||
) (*msgstream.MsgPack, error) {
|
||||
msgPack := &msgstream.MsgPack{
|
||||
BeginTs: insertMsg.BeginTs(),
|
||||
EndTs: insertMsg.EndTs(),
|
||||
}
|
||||
|
||||
channel2RowOffsets := assignChannelsByPK(result.IDs, channelNames, insertMsg)
|
||||
partitionNames, err := getDefaultPartitionsInPartitionKeyMode(ctx, insertMsg.GetDbName(), insertMsg.CollectionName)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("get default partition names failed in partition key mode",
|
||||
zap.String("collectionName", insertMsg.CollectionName),
|
||||
zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
hashValues, err := typeutil.HashKey2Partitions(partitionKeys, partitionNames)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("has partition keys to partitions failed",
|
||||
zap.String("collectionName", insertMsg.CollectionName),
|
||||
zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for channel, rowOffsets := range channel2RowOffsets {
|
||||
partition2RowOffsets := make(map[string][]int)
|
||||
for _, idx := range rowOffsets {
|
||||
partitionName := partitionNames[hashValues[idx]]
|
||||
if _, ok := partition2RowOffsets[partitionName]; !ok {
|
||||
partition2RowOffsets[partitionName] = []int{}
|
||||
}
|
||||
partition2RowOffsets[partitionName] = append(partition2RowOffsets[partitionName], idx)
|
||||
}
|
||||
|
||||
errGroup, _ := errgroup.WithContext(ctx)
|
||||
partition2Msgs := typeutil.NewConcurrentMap[string, []msgstream.TsMsg]()
|
||||
for partitionName, offsets := range partition2RowOffsets {
|
||||
partitionName := partitionName
|
||||
offsets := offsets
|
||||
errGroup.Go(func() error {
|
||||
msgs, err := repackInsertDataByPartition(ctx, partitionName, offsets, channel, insertMsg, segIDAssigner)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
partition2Msgs.Insert(partitionName, msgs)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
err = errGroup.Wait()
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("repack insert data into insert msg pack failed",
|
||||
zap.String("collectionName", insertMsg.CollectionName),
|
||||
zap.String("channelName", channel),
|
||||
zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
partition2Msgs.Range(func(name string, msgs []msgstream.TsMsg) bool {
|
||||
msgPack.Msgs = append(msgPack.Msgs, msgs...)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
err = setMsgID(ctx, msgPack.Msgs, idAllocator)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error("failed to set msgID when repack insert data",
|
||||
zap.String("collectionName", insertMsg.CollectionName),
|
||||
zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return msgPack, nil
|
||||
}
|
||||
|
||||
@ -21,7 +21,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
@ -49,12 +48,6 @@ func TestRepackInsertData(t *testing.T) {
|
||||
defer mix.Close()
|
||||
|
||||
cache := NewMockCache(t)
|
||||
cache.On("GetPartitionID",
|
||||
mock.Anything, // context.Context
|
||||
mock.AnythingOfType("string"),
|
||||
mock.AnythingOfType("string"),
|
||||
mock.AnythingOfType("string"),
|
||||
).Return(int64(1), nil)
|
||||
globalMetaCache = cache
|
||||
|
||||
idAllocator, err := allocator.NewIDAllocator(ctx, mix, paramtable.GetNodeID())
|
||||
@ -113,33 +106,6 @@ func TestRepackInsertData(t *testing.T) {
|
||||
for index := range insertMsg.RowIDs {
|
||||
insertMsg.RowIDs[index] = int64(index)
|
||||
}
|
||||
|
||||
ids, err := parsePrimaryFieldData2IDs(fieldData)
|
||||
assert.NoError(t, err)
|
||||
result := &milvuspb.MutationResult{
|
||||
IDs: ids,
|
||||
}
|
||||
|
||||
t.Run("assign segmentID failed", func(t *testing.T) {
|
||||
fakeSegAllocator, err := newSegIDAssigner(ctx, &mockDataCoord2{expireTime: Timestamp(2500)}, getLastTick1)
|
||||
assert.NoError(t, err)
|
||||
_ = fakeSegAllocator.Start()
|
||||
defer fakeSegAllocator.Close()
|
||||
|
||||
_, err = repackInsertData(ctx, []string{"test_dml_channel"}, insertMsg,
|
||||
result, idAllocator, fakeSegAllocator)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
|
||||
assert.NoError(t, err)
|
||||
_ = segAllocator.Start()
|
||||
defer segAllocator.Close()
|
||||
|
||||
t.Run("repack insert data success", func(t *testing.T) {
|
||||
_, err = repackInsertData(ctx, []string{"test_dml_channel"}, insertMsg, result, idAllocator, segAllocator)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRepackInsertDataWithPartitionKey(t *testing.T) {
|
||||
@ -161,11 +127,6 @@ func TestRepackInsertDataWithPartitionKey(t *testing.T) {
|
||||
_ = idAllocator.Start()
|
||||
defer idAllocator.Close()
|
||||
|
||||
segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
|
||||
assert.NoError(t, err)
|
||||
_ = segAllocator.Start()
|
||||
defer segAllocator.Close()
|
||||
|
||||
fieldName2Types := map[string]schemapb.DataType{
|
||||
testInt64Field: schemapb.DataType_Int64,
|
||||
testVarCharField: schemapb.DataType_VarChar,
|
||||
@ -221,17 +182,4 @@ func TestRepackInsertDataWithPartitionKey(t *testing.T) {
|
||||
for index := range insertMsg.RowIDs {
|
||||
insertMsg.RowIDs[index] = int64(index)
|
||||
}
|
||||
|
||||
ids, err := parsePrimaryFieldData2IDs(fieldNameToDatas[testInt64Field])
|
||||
assert.NoError(t, err)
|
||||
result := &milvuspb.MutationResult{
|
||||
IDs: ids,
|
||||
}
|
||||
|
||||
t.Run("repack insert data success", func(t *testing.T) {
|
||||
partitionKeys := generateFieldData(schemapb.DataType_VarChar, testVarCharField, nb)
|
||||
_, err = repackInsertDataWithPartitionKey(ctx, []string{"test_dml_channel"}, partitionKeys,
|
||||
insertMsg, result, idAllocator, segAllocator)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@ -21,14 +21,12 @@ import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/google/uuid"
|
||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
|
||||
@ -40,19 +38,15 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/hookutil"
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
"github.com/milvus-io/milvus/internal/util/streamingutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/expr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/logutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/metricsinfo"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/ratelimitutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/resource"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
@ -87,7 +81,6 @@ type Proxy struct {
|
||||
|
||||
stateCode atomic.Int32
|
||||
|
||||
etcdCli *clientv3.Client
|
||||
address string
|
||||
mixCoord types.MixCoordClient
|
||||
|
||||
@ -95,23 +88,16 @@ type Proxy struct {
|
||||
|
||||
chMgr channelsMgr
|
||||
|
||||
replicateMsgStream msgstream.MsgStream
|
||||
|
||||
sched *taskScheduler
|
||||
|
||||
chTicker channelsTimeTicker
|
||||
|
||||
rowIDAllocator *allocator.IDAllocator
|
||||
tsoAllocator *timestampAllocator
|
||||
segAssigner *segIDAssigner
|
||||
|
||||
metricsCacheManager *metricsinfo.MetricsCacheManager
|
||||
|
||||
session *sessionutil.Session
|
||||
shardMgr shardClientMgr
|
||||
|
||||
factory dependency.Factory
|
||||
|
||||
searchResultCh chan *internalpb.SearchResults
|
||||
|
||||
// Add callback functions at different stages
|
||||
@ -122,8 +108,7 @@ type Proxy struct {
|
||||
lbPolicy LBPolicy
|
||||
|
||||
// resource manager
|
||||
resourceManager resource.Manager
|
||||
replicateStreamManager *ReplicateStreamManager
|
||||
resourceManager resource.Manager
|
||||
|
||||
// materialized view
|
||||
enableMaterializedView bool
|
||||
@ -135,7 +120,7 @@ type Proxy struct {
|
||||
}
|
||||
|
||||
// NewProxy returns a Proxy struct.
|
||||
func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
|
||||
func NewProxy(ctx context.Context, _ dependency.Factory) (*Proxy, error) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
ctx1, cancel := context.WithCancel(ctx)
|
||||
n := 1024 // better to be configurable
|
||||
@ -143,18 +128,15 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
|
||||
lbPolicy := NewLBPolicyImpl(mgr)
|
||||
lbPolicy.Start(ctx)
|
||||
resourceManager := resource.NewManager(10*time.Second, 20*time.Second, make(map[string]time.Duration))
|
||||
replicateStreamManager := NewReplicateStreamManager(ctx, factory, resourceManager)
|
||||
node := &Proxy{
|
||||
ctx: ctx1,
|
||||
cancel: cancel,
|
||||
factory: factory,
|
||||
searchResultCh: make(chan *internalpb.SearchResults, n),
|
||||
shardMgr: mgr,
|
||||
simpleLimiter: NewSimpleLimiter(Params.QuotaConfig.AllocWaitInterval.GetAsDuration(time.Millisecond), Params.QuotaConfig.AllocRetryTimes.GetAsUint()),
|
||||
lbPolicy: lbPolicy,
|
||||
resourceManager: resourceManager,
|
||||
replicateStreamManager: replicateStreamManager,
|
||||
slowQueries: expirable.NewLRU[Timestamp, *metricsinfo.SlowQuery](20, nil, time.Minute*15),
|
||||
ctx: ctx1,
|
||||
cancel: cancel,
|
||||
searchResultCh: make(chan *internalpb.SearchResults, n),
|
||||
shardMgr: mgr,
|
||||
simpleLimiter: NewSimpleLimiter(Params.QuotaConfig.AllocWaitInterval.GetAsDuration(time.Millisecond), Params.QuotaConfig.AllocRetryTimes.GetAsUint()),
|
||||
lbPolicy: lbPolicy,
|
||||
resourceManager: resourceManager,
|
||||
slowQueries: expirable.NewLRU[Timestamp, *metricsinfo.SlowQuery](20, nil, time.Minute*15),
|
||||
}
|
||||
node.UpdateStateCode(commonpb.StateCode_Abnormal)
|
||||
expr.Register("proxy", node)
|
||||
@ -223,10 +205,6 @@ func (node *Proxy) Init() error {
|
||||
}
|
||||
log.Info("init session for Proxy done")
|
||||
|
||||
node.factory.Init(Params)
|
||||
|
||||
log.Debug("init access log for Proxy done")
|
||||
|
||||
err := node.initRateCollector()
|
||||
if err != nil {
|
||||
return err
|
||||
@ -253,44 +231,18 @@ func (node *Proxy) Init() error {
|
||||
node.tsoAllocator = tsoAllocator
|
||||
log.Debug("create timestamp allocator done", zap.String("role", typeutil.ProxyRole), zap.Int64("ProxyID", paramtable.GetNodeID()))
|
||||
|
||||
segAssigner, err := newSegIDAssigner(node.ctx, node.mixCoord, node.lastTick)
|
||||
if err != nil {
|
||||
log.Warn("failed to create segment id assigner",
|
||||
zap.String("role", typeutil.ProxyRole), zap.Int64("ProxyID", paramtable.GetNodeID()),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
node.segAssigner = segAssigner
|
||||
node.segAssigner.PeerID = paramtable.GetNodeID()
|
||||
log.Debug("create segment id assigner done", zap.String("role", typeutil.ProxyRole), zap.Int64("ProxyID", paramtable.GetNodeID()))
|
||||
|
||||
dmlChannelsFunc := getDmlChannelsFunc(node.ctx, node.mixCoord)
|
||||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, defaultInsertRepackFunc, node.factory)
|
||||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, defaultInsertRepackFunc)
|
||||
node.chMgr = chMgr
|
||||
log.Debug("create channels manager done", zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
replicateMsgChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
|
||||
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
|
||||
if err != nil {
|
||||
log.Warn("failed to create replicate msg stream",
|
||||
zap.String("role", typeutil.ProxyRole), zap.Int64("ProxyID", paramtable.GetNodeID()),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
node.replicateMsgStream.ForceEnableProduce(true)
|
||||
node.replicateMsgStream.AsProducer(node.ctx, []string{replicateMsgChannel})
|
||||
|
||||
node.sched, err = newTaskScheduler(node.ctx, node.tsoAllocator, node.factory)
|
||||
node.sched, err = newTaskScheduler(node.ctx, node.tsoAllocator)
|
||||
if err != nil {
|
||||
log.Warn("failed to create task scheduler", zap.String("role", typeutil.ProxyRole), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
log.Debug("create task scheduler done", zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
syncTimeTickInterval := Params.ProxyCfg.TimeTickInterval.GetAsDuration(time.Millisecond) / 2
|
||||
node.chTicker = newChannelsTimeTicker(node.ctx, Params.ProxyCfg.TimeTickInterval.GetAsDuration(time.Millisecond)/2, []string{}, node.sched.getPChanStatistics, tsoAllocator)
|
||||
log.Debug("create channels time ticker done", zap.String("role", typeutil.ProxyRole), zap.Duration("syncTimeTickInterval", syncTimeTickInterval))
|
||||
|
||||
node.enableComplexDeleteLimit = Params.QuotaConfig.ComplexDeleteLimitEnable.GetAsBool()
|
||||
node.metricsCacheManager = metricsinfo.NewMetricsCacheManager()
|
||||
log.Debug("create metrics cache manager done", zap.String("role", typeutil.ProxyRole))
|
||||
@ -314,90 +266,6 @@ func (node *Proxy) Init() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendChannelsTimeTickLoop starts a goroutine that synchronizes the time tick information.
|
||||
func (node *Proxy) sendChannelsTimeTickLoop() {
|
||||
log := log.Ctx(node.ctx)
|
||||
node.wg.Add(1)
|
||||
go func() {
|
||||
defer node.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(Params.ProxyCfg.TimeTickInterval.GetAsDuration(time.Millisecond))
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-node.ctx.Done():
|
||||
log.Info("send channels time tick loop exit")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if !Params.CommonCfg.TTMsgEnabled.GetAsBool() {
|
||||
continue
|
||||
}
|
||||
stats, ts, err := node.chTicker.getMinTsStatistics()
|
||||
if err != nil {
|
||||
log.Warn("sendChannelsTimeTickLoop.getMinTsStatistics", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
if ts == 0 {
|
||||
log.Warn("sendChannelsTimeTickLoop.getMinTsStatistics default timestamp equal 0")
|
||||
continue
|
||||
}
|
||||
|
||||
channels := make([]pChan, 0, len(stats))
|
||||
tss := make([]Timestamp, 0, len(stats))
|
||||
|
||||
maxTs := ts
|
||||
for channel, ts := range stats {
|
||||
channels = append(channels, channel)
|
||||
tss = append(tss, ts)
|
||||
if ts > maxTs {
|
||||
maxTs = ts
|
||||
}
|
||||
}
|
||||
|
||||
req := &internalpb.ChannelTimeTickMsg{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_TimeTick),
|
||||
commonpbutil.WithSourceID(node.session.ServerID),
|
||||
),
|
||||
ChannelNames: channels,
|
||||
Timestamps: tss,
|
||||
DefaultTimestamp: maxTs,
|
||||
}
|
||||
|
||||
func() {
|
||||
// we should pay more attention to the max lag.
|
||||
minTs := maxTs
|
||||
minTsOfChannel := "default"
|
||||
|
||||
// find the min ts and the related channel.
|
||||
for channel, ts := range stats {
|
||||
if ts < minTs {
|
||||
minTs = ts
|
||||
minTsOfChannel = channel
|
||||
}
|
||||
}
|
||||
|
||||
sub := tsoutil.SubByNow(minTs)
|
||||
metrics.ProxySyncTimeTickLag.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), minTsOfChannel).Set(float64(sub))
|
||||
}()
|
||||
|
||||
status, err := node.mixCoord.UpdateChannelTimeTick(node.ctx, req)
|
||||
if err != nil {
|
||||
log.Warn("sendChannelsTimeTickLoop.UpdateChannelTimeTick", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
if status.GetErrorCode() != 0 {
|
||||
log.Warn("sendChannelsTimeTickLoop.UpdateChannelTimeTick",
|
||||
zap.Any("ErrorCode", status.ErrorCode),
|
||||
zap.Any("Reason", status.Reason))
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Start starts a proxy node.
|
||||
func (node *Proxy) Start() error {
|
||||
log := log.Ctx(node.ctx)
|
||||
@ -417,22 +285,6 @@ func (node *Proxy) Start() error {
|
||||
}
|
||||
log.Debug("start id allocator done", zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
if !streamingutil.IsStreamingServiceEnabled() {
|
||||
if err := node.segAssigner.Start(); err != nil {
|
||||
log.Warn("failed to start segment id assigner", zap.String("role", typeutil.ProxyRole), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
log.Debug("start segment id assigner done", zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
if err := node.chTicker.start(); err != nil {
|
||||
log.Warn("failed to start channels time ticker", zap.String("role", typeutil.ProxyRole), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
log.Debug("start channels time ticker done", zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
node.sendChannelsTimeTickLoop()
|
||||
}
|
||||
|
||||
// Start callbacks
|
||||
for _, cb := range node.startCallbacks {
|
||||
cb()
|
||||
@ -465,21 +317,6 @@ func (node *Proxy) Stop() error {
|
||||
log.Info("close scheduler", zap.String("role", typeutil.ProxyRole))
|
||||
}
|
||||
|
||||
if !streamingutil.IsStreamingServiceEnabled() {
|
||||
if node.segAssigner != nil {
|
||||
node.segAssigner.Close()
|
||||
log.Info("close segment id assigner", zap.String("role", typeutil.ProxyRole))
|
||||
}
|
||||
|
||||
if node.chTicker != nil {
|
||||
err := node.chTicker.close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Info("close channels time ticker", zap.String("role", typeutil.ProxyRole))
|
||||
}
|
||||
}
|
||||
|
||||
for _, cb := range node.closeCallbacks {
|
||||
cb()
|
||||
}
|
||||
@ -492,10 +329,6 @@ func (node *Proxy) Stop() error {
|
||||
node.shardMgr.Close()
|
||||
}
|
||||
|
||||
if node.chMgr != nil {
|
||||
node.chMgr.removeAllDMLStream()
|
||||
}
|
||||
|
||||
if node.lbPolicy != nil {
|
||||
node.lbPolicy.Close()
|
||||
}
|
||||
@ -519,11 +352,6 @@ func (node *Proxy) AddStartCallback(callbacks ...func()) {
|
||||
node.startCallbacks = append(node.startCallbacks, callbacks...)
|
||||
}
|
||||
|
||||
// lastTick returns the last write timestamp of all pchans in this Proxy.
|
||||
func (node *Proxy) lastTick() Timestamp {
|
||||
return node.chTicker.getMinTick()
|
||||
}
|
||||
|
||||
// AddCloseCallback adds a callback in the Close phase.
|
||||
func (node *Proxy) AddCloseCallback(callbacks ...func()) {
|
||||
node.closeCallbacks = append(node.closeCallbacks, callbacks...)
|
||||
@ -537,11 +365,6 @@ func (node *Proxy) GetAddress() string {
|
||||
return node.address
|
||||
}
|
||||
|
||||
// SetEtcdClient sets etcd client for proxy.
|
||||
func (node *Proxy) SetEtcdClient(client *clientv3.Client) {
|
||||
node.etcdCli = client
|
||||
}
|
||||
|
||||
// SetMixCoordClient sets MixCoord client for proxy.
|
||||
func (node *Proxy) SetMixCoordClient(cli types.MixCoordClient) {
|
||||
node.mixCoord = cli
|
||||
|
||||
@ -23,13 +23,11 @@ import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/blang/semver/v4"
|
||||
"github.com/cockroachdb/errors"
|
||||
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
@ -50,12 +48,13 @@ import (
|
||||
mixc "github.com/milvus-io/milvus/internal/distributed/mixcoord/client"
|
||||
grpcquerynode "github.com/milvus-io/milvus/internal/distributed/querynode"
|
||||
"github.com/milvus-io/milvus/internal/distributed/streaming"
|
||||
grpcstreamingnode "github.com/milvus-io/milvus/internal/distributed/streamingnode"
|
||||
"github.com/milvus-io/milvus/internal/json"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/mocks/distributed/mock_streaming"
|
||||
"github.com/milvus-io/milvus/internal/util/componentutil"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
"github.com/milvus-io/milvus/internal/util/streamingutil"
|
||||
"github.com/milvus-io/milvus/internal/util/testutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
@ -64,7 +63,6 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/proxypb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/tracer"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/crypto"
|
||||
@ -89,7 +87,6 @@ func init() {
|
||||
Registry = prometheus.NewRegistry()
|
||||
Registry.MustRegister(prometheus.NewProcessCollector(prometheus.ProcessCollectorOpts{}))
|
||||
Registry.MustRegister(prometheus.NewGoCollector())
|
||||
common.Version = semver.MustParse("2.5.9")
|
||||
}
|
||||
|
||||
func runMixCoord(ctx context.Context, localMsg bool) *grpcmixcoord.Server {
|
||||
@ -119,6 +116,33 @@ func runMixCoord(ctx context.Context, localMsg bool) *grpcmixcoord.Server {
|
||||
return rc
|
||||
}
|
||||
|
||||
func runStreamingNode(ctx context.Context, localMsg bool, alias string) *grpcstreamingnode.Server {
|
||||
var sn *grpcstreamingnode.Server
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
factory := dependency.MockDefaultFactory(localMsg, Params)
|
||||
var err error
|
||||
sn, err = grpcstreamingnode.NewServer(ctx, factory)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err = sn.Prepare(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = sn.Run()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
metrics.RegisterStreamingNode(Registry)
|
||||
return sn
|
||||
}
|
||||
|
||||
func runQueryNode(ctx context.Context, localMsg bool, alias string) *grpcquerynode.Server {
|
||||
var qn *grpcquerynode.Server
|
||||
var wg sync.WaitGroup
|
||||
@ -276,17 +300,9 @@ func TestProxy(t *testing.T) {
|
||||
paramtable.Init()
|
||||
params := paramtable.Get()
|
||||
testutil.ResetEnvironment()
|
||||
|
||||
wal := mock_streaming.NewMockWALAccesser(t)
|
||||
b := mock_streaming.NewMockBroadcast(t)
|
||||
wal.EXPECT().Broadcast().Return(b).Maybe()
|
||||
b.EXPECT().Append(mock.Anything, mock.Anything).Return(&types.BroadcastAppendResult{}, nil).Maybe()
|
||||
local := mock_streaming.NewMockLocal(t)
|
||||
local.EXPECT().GetLatestMVCCTimestampIfLocal(mock.Anything, mock.Anything).Return(0, nil).Maybe()
|
||||
local.EXPECT().GetMetricsIfLocal(mock.Anything).Return(&types.StreamingNodeMetrics{}, nil).Maybe()
|
||||
wal.EXPECT().Local().Return(local).Maybe()
|
||||
streaming.SetWALForTest(wal)
|
||||
defer streaming.RecoverWALForTest()
|
||||
paramtable.SetLocalComponentEnabled(typeutil.StreamingNodeRole)
|
||||
streamingutil.SetStreamingServiceEnabled()
|
||||
defer streamingutil.UnsetStreamingServiceEnabled()
|
||||
|
||||
params.RootCoordGrpcServerCfg.IP = "localhost"
|
||||
params.QueryCoordGrpcServerCfg.IP = "localhost"
|
||||
@ -295,15 +311,14 @@ func TestProxy(t *testing.T) {
|
||||
params.QueryNodeGrpcServerCfg.IP = "localhost"
|
||||
params.DataNodeGrpcServerCfg.IP = "localhost"
|
||||
params.StreamingNodeGrpcServerCfg.IP = "localhost"
|
||||
|
||||
path := "/tmp/milvus/rocksmq" + funcutil.GenRandomStr()
|
||||
t.Setenv("ROCKSMQ_PATH", path)
|
||||
defer os.RemoveAll(path)
|
||||
params.Save(params.MQCfg.Type.Key, "pulsar")
|
||||
params.CommonCfg.EnableStorageV2.SwapTempValue("false")
|
||||
defer params.CommonCfg.EnableStorageV2.SwapTempValue("")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx = GetContext(ctx, "root:123456")
|
||||
localMsg := true
|
||||
factory := dependency.MockDefaultFactory(localMsg, Params)
|
||||
factory := dependency.NewDefaultFactory(false)
|
||||
alias := "TestProxy"
|
||||
|
||||
log.Info("Initialize parameter table of Proxy")
|
||||
@ -314,11 +329,16 @@ func TestProxy(t *testing.T) {
|
||||
dn := runDataNode(ctx, localMsg, alias)
|
||||
log.Info("running DataNode ...")
|
||||
|
||||
sn := runStreamingNode(ctx, localMsg, alias)
|
||||
log.Info("running StreamingNode ...")
|
||||
|
||||
qn := runQueryNode(ctx, localMsg, alias)
|
||||
log.Info("running QueryNode ...")
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
streaming.Init()
|
||||
|
||||
proxy, err := NewProxy(ctx, factory)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, proxy)
|
||||
@ -331,9 +351,11 @@ func TestProxy(t *testing.T) {
|
||||
Params.EtcdCfg.EtcdTLSKey.GetValue(),
|
||||
Params.EtcdCfg.EtcdTLSCACert.GetValue(),
|
||||
Params.EtcdCfg.EtcdTLSMinVersion.GetValue())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer etcdcli.Close()
|
||||
assert.NoError(t, err)
|
||||
proxy.SetEtcdClient(etcdcli)
|
||||
|
||||
testServer := newProxyTestServer(proxy)
|
||||
wg.Add(1)
|
||||
@ -372,7 +394,7 @@ func TestProxy(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
log.Info("Register proxy done")
|
||||
defer func() {
|
||||
a := []any{mix, qn, dn, proxy}
|
||||
a := []any{mix, qn, dn, sn, proxy}
|
||||
fmt.Println(len(a))
|
||||
// HINT: the order of stopping service refers to the `roles.go` file
|
||||
log.Info("start to stop the services")
|
||||
@ -1106,6 +1128,10 @@ func TestProxy(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
|
||||
segmentIDs = resp.CollSegIDs[collectionName].Data
|
||||
// TODO: Here's a Bug, because a growing segment may cannot be seen right away by mixcoord,
|
||||
// it can only be seen by streamingnode right away, so we need to check the flush state at streamingnode but not here.
|
||||
// use timetick for GetFlushState in-future but not segment list.
|
||||
time.Sleep(5 * time.Second)
|
||||
log.Info("flush collection", zap.Int64s("segments to be flushed", segmentIDs))
|
||||
|
||||
f := func() bool {
|
||||
@ -4792,11 +4818,7 @@ func TestProxy_Import(t *testing.T) {
|
||||
cache := globalMetaCache
|
||||
defer func() { globalMetaCache = cache }()
|
||||
|
||||
wal := mock_streaming.NewMockWALAccesser(t)
|
||||
b := mock_streaming.NewMockBroadcast(t)
|
||||
wal.EXPECT().Broadcast().Return(b).Maybe()
|
||||
streaming.SetWALForTest(wal)
|
||||
defer streaming.RecoverWALForTest()
|
||||
streaming.SetupNoopWALForTest()
|
||||
|
||||
t.Run("Import failed", func(t *testing.T) {
|
||||
proxy := &Proxy{}
|
||||
@ -4825,7 +4847,6 @@ func TestProxy_Import(t *testing.T) {
|
||||
chMgr.EXPECT().getVChannels(mock.Anything).Return([]string{"foo"}, nil)
|
||||
proxy.chMgr = chMgr
|
||||
|
||||
factory := dependency.NewDefaultFactory(true)
|
||||
rc := mocks.NewMockRootCoordClient(t)
|
||||
rc.EXPECT().AllocID(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocIDResponse{
|
||||
ID: rand.Int63(),
|
||||
@ -4839,19 +4860,12 @@ func TestProxy_Import(t *testing.T) {
|
||||
proxy.tsoAllocator = ×tampAllocator{
|
||||
tso: newMockTimestampAllocatorInterface(),
|
||||
}
|
||||
scheduler, err := newTaskScheduler(ctx, proxy.tsoAllocator, factory)
|
||||
scheduler, err := newTaskScheduler(ctx, proxy.tsoAllocator)
|
||||
assert.NoError(t, err)
|
||||
proxy.sched = scheduler
|
||||
err = proxy.sched.Start()
|
||||
assert.NoError(t, err)
|
||||
|
||||
wal := mock_streaming.NewMockWALAccesser(t)
|
||||
b := mock_streaming.NewMockBroadcast(t)
|
||||
wal.EXPECT().Broadcast().Return(b)
|
||||
b.EXPECT().Append(mock.Anything, mock.Anything).Return(&types.BroadcastAppendResult{}, nil)
|
||||
streaming.SetWALForTest(wal)
|
||||
defer streaming.RecoverWALForTest()
|
||||
|
||||
req := &milvuspb.ImportRequest{
|
||||
CollectionName: "dummy",
|
||||
Files: []string{"a.json"},
|
||||
|
||||
@ -1,72 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/resource"
|
||||
)
|
||||
|
||||
const (
|
||||
ReplicateMsgStreamTyp = "replicate_msg_stream"
|
||||
ReplicateMsgStreamExpireTime = 30 * time.Second
|
||||
)
|
||||
|
||||
type ReplicateStreamManager struct {
|
||||
ctx context.Context
|
||||
factory msgstream.Factory
|
||||
dispatcher msgstream.UnmarshalDispatcher
|
||||
resourceManager resource.Manager
|
||||
}
|
||||
|
||||
func NewReplicateStreamManager(ctx context.Context, factory msgstream.Factory, resourceManager resource.Manager) *ReplicateStreamManager {
|
||||
manager := &ReplicateStreamManager{
|
||||
ctx: ctx,
|
||||
factory: factory,
|
||||
dispatcher: (&msgstream.ProtoUDFactory{}).NewUnmarshalDispatcher(),
|
||||
resourceManager: resourceManager,
|
||||
}
|
||||
return manager
|
||||
}
|
||||
|
||||
func (m *ReplicateStreamManager) newMsgStreamResource(ctx context.Context, channel string) resource.NewResourceFunc {
|
||||
return func() (resource.Resource, error) {
|
||||
msgStream, err := m.factory.NewMsgStream(ctx)
|
||||
if err != nil {
|
||||
log.Ctx(m.ctx).Warn("failed to create msg stream", zap.String("channel", channel), zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
msgStream.SetRepackFunc(replicatePackFunc)
|
||||
msgStream.AsProducer(ctx, []string{channel})
|
||||
msgStream.ForceEnableProduce(true)
|
||||
|
||||
res := resource.NewSimpleResource(msgStream, ReplicateMsgStreamTyp, channel, ReplicateMsgStreamExpireTime, func() {
|
||||
msgStream.Close()
|
||||
})
|
||||
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ReplicateStreamManager) GetReplicateMsgStream(ctx context.Context, channel string) (msgstream.MsgStream, error) {
|
||||
ctxLog := log.Ctx(ctx).With(zap.String("proxy_channel", channel))
|
||||
res, err := m.resourceManager.Get(ReplicateMsgStreamTyp, channel, m.newMsgStreamResource(ctx, channel))
|
||||
if err != nil {
|
||||
ctxLog.Warn("failed to get replicate msg stream", zap.String("channel", channel), zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
if obj, ok := res.Get().(msgstream.MsgStream); ok && obj != nil {
|
||||
return obj, nil
|
||||
}
|
||||
ctxLog.Warn("invalid resource object", zap.Any("obj", res.Get()))
|
||||
return nil, merr.ErrInvalidStreamObj
|
||||
}
|
||||
|
||||
func (m *ReplicateStreamManager) GetMsgDispatcher() msgstream.UnmarshalDispatcher {
|
||||
return m.dispatcher
|
||||
}
|
||||
@ -1,85 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/resource"
|
||||
)
|
||||
|
||||
func TestReplicateManager(t *testing.T) {
|
||||
factory := newMockMsgStreamFactory()
|
||||
resourceManager := resource.NewManager(time.Second, 2*time.Second, nil)
|
||||
manager := NewReplicateStreamManager(context.Background(), factory, resourceManager)
|
||||
|
||||
{
|
||||
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
return nil, errors.New("mock msgstream fail")
|
||||
}
|
||||
_, err := manager.GetReplicateMsgStream(context.Background(), "test")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
{
|
||||
mockMsgStream := newMockMsgStream()
|
||||
i := 0
|
||||
mockMsgStream.setRepack = func(repackFunc msgstream.RepackFunc) {
|
||||
i++
|
||||
}
|
||||
mockMsgStream.asProducer = func(producers []string) {
|
||||
i++
|
||||
}
|
||||
mockMsgStream.forceEnableProduce = func(b bool) {
|
||||
i++
|
||||
}
|
||||
mockMsgStream.close = func() {
|
||||
i++
|
||||
}
|
||||
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
|
||||
return mockMsgStream, nil
|
||||
}
|
||||
_, err := manager.GetReplicateMsgStream(context.Background(), "test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, i)
|
||||
time.Sleep(time.Second)
|
||||
_, err = manager.GetReplicateMsgStream(context.Background(), "test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, i)
|
||||
res := resourceManager.Delete(ReplicateMsgStreamTyp, "test")
|
||||
assert.NotNil(t, res)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return resourceManager.Delete(ReplicateMsgStreamTyp, "test") == nil
|
||||
}, time.Second*4, time.Millisecond*500)
|
||||
|
||||
_, err = manager.GetReplicateMsgStream(context.Background(), "test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 7, i)
|
||||
}
|
||||
{
|
||||
res := resourceManager.Delete(ReplicateMsgStreamTyp, "test")
|
||||
assert.NotNil(t, res)
|
||||
|
||||
assert.Eventually(t, func() bool {
|
||||
return resourceManager.Delete(ReplicateMsgStreamTyp, "test") == nil
|
||||
}, time.Second*4, time.Millisecond*500)
|
||||
|
||||
res, err := resourceManager.Get(ReplicateMsgStreamTyp, "test", func() (resource.Resource, error) {
|
||||
return resource.NewResource(resource.WithObj("str")), nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "str", res.Get())
|
||||
|
||||
_, err = manager.GetReplicateMsgStream(context.Background(), "test")
|
||||
assert.ErrorIs(t, err, merr.ErrInvalidStreamObj)
|
||||
}
|
||||
|
||||
{
|
||||
assert.NotNil(t, manager.GetMsgDispatcher())
|
||||
}
|
||||
}
|
||||
@ -1,401 +0,0 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
const (
|
||||
segCountPerRPC = 20000
|
||||
)
|
||||
|
||||
// DataCoord is a narrowed interface of DataCoordinator which only provide AssignSegmentID method
|
||||
type DataCoord interface {
|
||||
AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error)
|
||||
}
|
||||
|
||||
type segRequest struct {
|
||||
allocator.BaseRequest
|
||||
count uint32
|
||||
collID UniqueID
|
||||
partitionID UniqueID
|
||||
segInfo map[UniqueID]uint32
|
||||
channelName string
|
||||
timestamp Timestamp
|
||||
}
|
||||
|
||||
type segInfo struct {
|
||||
segID UniqueID
|
||||
count uint32
|
||||
expireTime Timestamp
|
||||
}
|
||||
|
||||
type assignInfo struct {
|
||||
collID UniqueID
|
||||
partitionID UniqueID
|
||||
channelName string
|
||||
segInfos *list.List
|
||||
lastInsertTime time.Time
|
||||
}
|
||||
|
||||
func (info *segInfo) IsExpired(ts Timestamp) bool {
|
||||
return ts > info.expireTime || info.count <= 0
|
||||
}
|
||||
|
||||
func (info *segInfo) Capacity(ts Timestamp) uint32 {
|
||||
if info.IsExpired(ts) {
|
||||
return 0
|
||||
}
|
||||
return info.count
|
||||
}
|
||||
|
||||
func (info *segInfo) Assign(ts Timestamp, count uint32) uint32 {
|
||||
if info.IsExpired(ts) {
|
||||
log.Ctx(context.TODO()).Debug("segInfo Assign IsExpired", zap.Uint64("ts", ts),
|
||||
zap.Uint32("count", count))
|
||||
return 0
|
||||
}
|
||||
ret := uint32(0)
|
||||
if info.count >= count {
|
||||
info.count -= count
|
||||
ret = count
|
||||
} else {
|
||||
ret = info.count
|
||||
info.count = 0
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (info *assignInfo) RemoveExpired(ts Timestamp) {
|
||||
var next *list.Element
|
||||
for e := info.segInfos.Front(); e != nil; e = next {
|
||||
next = e.Next()
|
||||
segInfo, ok := e.Value.(*segInfo)
|
||||
if !ok {
|
||||
log.Warn("can not cast to segInfo")
|
||||
continue
|
||||
}
|
||||
if segInfo.IsExpired(ts) {
|
||||
info.segInfos.Remove(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (info *assignInfo) Capacity(ts Timestamp) uint32 {
|
||||
ret := uint32(0)
|
||||
for e := info.segInfos.Front(); e != nil; e = e.Next() {
|
||||
segInfo := e.Value.(*segInfo)
|
||||
ret += segInfo.Capacity(ts)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (info *assignInfo) Assign(ts Timestamp, count uint32) (map[UniqueID]uint32, error) {
|
||||
capacity := info.Capacity(ts)
|
||||
if capacity < count {
|
||||
errMsg := fmt.Sprintf("AssignSegment Failed: capacity:%d is less than count:%d", capacity, count)
|
||||
return nil, errors.New(errMsg)
|
||||
}
|
||||
|
||||
result := make(map[UniqueID]uint32)
|
||||
for e := info.segInfos.Front(); e != nil && count != 0; e = e.Next() {
|
||||
segInfo := e.Value.(*segInfo)
|
||||
cur := segInfo.Assign(ts, count)
|
||||
count -= cur
|
||||
if cur > 0 {
|
||||
result[segInfo.segID] += cur
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type segIDAssigner struct {
|
||||
allocator.CachedAllocator
|
||||
assignInfos map[UniqueID]*list.List // collectionID -> *list.List
|
||||
segReqs []*datapb.SegmentIDRequest
|
||||
getTickFunc func() Timestamp
|
||||
PeerID UniqueID
|
||||
|
||||
dataCoord DataCoord
|
||||
countPerRPC uint32
|
||||
}
|
||||
|
||||
// newSegIDAssigner creates a new segIDAssigner
|
||||
func newSegIDAssigner(ctx context.Context, dataCoord DataCoord, getTickFunc func() Timestamp) (*segIDAssigner, error) {
|
||||
ctx1, cancel := context.WithCancel(ctx)
|
||||
sa := &segIDAssigner{
|
||||
CachedAllocator: allocator.CachedAllocator{
|
||||
Ctx: ctx1,
|
||||
CancelFunc: cancel,
|
||||
Role: "SegmentIDAllocator",
|
||||
},
|
||||
countPerRPC: segCountPerRPC,
|
||||
dataCoord: dataCoord,
|
||||
assignInfos: make(map[UniqueID]*list.List),
|
||||
getTickFunc: getTickFunc,
|
||||
}
|
||||
sa.TChan = &allocator.Ticker{
|
||||
UpdateInterval: time.Second,
|
||||
}
|
||||
sa.CachedAllocator.SyncFunc = sa.syncSegments
|
||||
sa.CachedAllocator.ProcessFunc = sa.processFunc
|
||||
sa.CachedAllocator.CheckSyncFunc = sa.checkSyncFunc
|
||||
sa.CachedAllocator.PickCanDoFunc = sa.pickCanDoFunc
|
||||
sa.Init()
|
||||
return sa, nil
|
||||
}
|
||||
|
||||
func (sa *segIDAssigner) collectExpired() {
|
||||
ts := sa.getTickFunc()
|
||||
var next *list.Element
|
||||
for _, info := range sa.assignInfos {
|
||||
for e := info.Front(); e != nil; e = next {
|
||||
next = e.Next()
|
||||
assign := e.Value.(*assignInfo)
|
||||
assign.RemoveExpired(ts)
|
||||
if assign.Capacity(ts) == 0 {
|
||||
info.Remove(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sa *segIDAssigner) pickCanDoFunc() {
|
||||
if sa.ToDoReqs == nil {
|
||||
return
|
||||
}
|
||||
records := make(map[UniqueID]map[UniqueID]map[string]uint32)
|
||||
var newTodoReqs []allocator.Request
|
||||
for _, req := range sa.ToDoReqs {
|
||||
segRequest := req.(*segRequest)
|
||||
collID := segRequest.collID
|
||||
partitionID := segRequest.partitionID
|
||||
channelName := segRequest.channelName
|
||||
|
||||
if _, ok := records[collID]; !ok {
|
||||
records[collID] = make(map[UniqueID]map[string]uint32)
|
||||
}
|
||||
if _, ok := records[collID][partitionID]; !ok {
|
||||
records[collID][partitionID] = make(map[string]uint32)
|
||||
}
|
||||
|
||||
if _, ok := records[collID][partitionID][channelName]; !ok {
|
||||
records[collID][partitionID][channelName] = 0
|
||||
}
|
||||
|
||||
records[collID][partitionID][channelName] += segRequest.count
|
||||
assign, err := sa.getAssign(segRequest.collID, segRequest.partitionID, segRequest.channelName)
|
||||
if err != nil || assign.Capacity(segRequest.timestamp) < records[collID][partitionID][channelName] {
|
||||
sa.segReqs = append(sa.segReqs, &datapb.SegmentIDRequest{
|
||||
ChannelName: channelName,
|
||||
Count: segRequest.count,
|
||||
CollectionID: collID,
|
||||
PartitionID: partitionID,
|
||||
})
|
||||
newTodoReqs = append(newTodoReqs, req)
|
||||
} else {
|
||||
sa.CanDoReqs = append(sa.CanDoReqs, req)
|
||||
}
|
||||
}
|
||||
log.Ctx(context.TODO()).Debug("Proxy segIDAssigner pickCanDoFunc", zap.Any("records", records),
|
||||
zap.Int("len(newTodoReqs)", len(newTodoReqs)),
|
||||
zap.Int("len(CanDoReqs)", len(sa.CanDoReqs)))
|
||||
sa.ToDoReqs = newTodoReqs
|
||||
}
|
||||
|
||||
func (sa *segIDAssigner) getAssign(collID UniqueID, partitionID UniqueID, channelName string) (*assignInfo, error) {
|
||||
assignInfos, ok := sa.assignInfos[collID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("can not find collection %d", collID)
|
||||
}
|
||||
|
||||
for e := assignInfos.Front(); e != nil; e = e.Next() {
|
||||
info := e.Value.(*assignInfo)
|
||||
if info.partitionID != partitionID || info.channelName != channelName {
|
||||
continue
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
return nil, fmt.Errorf("can not find assign info with collID %d, partitionID %d, channelName %s",
|
||||
collID, partitionID, channelName)
|
||||
}
|
||||
|
||||
func (sa *segIDAssigner) checkSyncFunc(timeout bool) bool {
|
||||
sa.collectExpired()
|
||||
return timeout || len(sa.segReqs) != 0
|
||||
}
|
||||
|
||||
func (sa *segIDAssigner) checkSegReqEqual(req1, req2 *datapb.SegmentIDRequest) bool {
|
||||
if req1 == nil || req2 == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if req1 == req2 {
|
||||
return true
|
||||
}
|
||||
return req1.CollectionID == req2.CollectionID && req1.PartitionID == req2.PartitionID && req1.ChannelName == req2.ChannelName
|
||||
}
|
||||
|
||||
func (sa *segIDAssigner) reduceSegReqs() {
|
||||
log.Ctx(context.TODO()).Debug("Proxy segIDAssigner reduceSegReqs", zap.Int("len(segReqs)", len(sa.segReqs)))
|
||||
if len(sa.segReqs) == 0 {
|
||||
return
|
||||
}
|
||||
beforeCnt := uint32(0)
|
||||
var newSegReqs []*datapb.SegmentIDRequest
|
||||
for _, req1 := range sa.segReqs {
|
||||
if req1.Count == 0 {
|
||||
log.Ctx(context.TODO()).Debug("Proxy segIDAssigner reduceSegReqs hit perCount == 0")
|
||||
req1.Count = sa.countPerRPC
|
||||
}
|
||||
beforeCnt += req1.Count
|
||||
var req2 *datapb.SegmentIDRequest
|
||||
for _, req3 := range newSegReqs {
|
||||
if sa.checkSegReqEqual(req1, req3) {
|
||||
req2 = req3
|
||||
break
|
||||
}
|
||||
}
|
||||
if req2 == nil { // not found
|
||||
newSegReqs = append(newSegReqs, req1)
|
||||
} else {
|
||||
req2.Count += req1.Count
|
||||
}
|
||||
}
|
||||
afterCnt := uint32(0)
|
||||
for _, req := range newSegReqs {
|
||||
afterCnt += req.Count
|
||||
}
|
||||
sa.segReqs = newSegReqs
|
||||
log.Ctx(context.TODO()).Debug("Proxy segIDAssigner reduceSegReqs after reduce", zap.Int("len(segReqs)", len(sa.segReqs)),
|
||||
zap.Uint32("BeforeCnt", beforeCnt),
|
||||
zap.Uint32("AfterCnt", afterCnt))
|
||||
}
|
||||
|
||||
func (sa *segIDAssigner) syncSegments() (bool, error) {
|
||||
if len(sa.segReqs) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
sa.reduceSegReqs()
|
||||
req := &datapb.AssignSegmentIDRequest{
|
||||
NodeID: sa.PeerID,
|
||||
PeerRole: typeutil.ProxyRole,
|
||||
SegmentIDRequests: sa.segReqs,
|
||||
}
|
||||
metrics.ProxySyncSegmentRequestLength.WithLabelValues(
|
||||
strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(len(sa.segReqs)))
|
||||
sa.segReqs = nil
|
||||
|
||||
log.Ctx(context.TODO()).Debug("syncSegments call dataCoord.AssignSegmentID", zap.Stringer("request", req))
|
||||
|
||||
resp, err := sa.dataCoord.AssignSegmentID(context.Background(), req)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("syncSegmentID Failed:%w", err)
|
||||
}
|
||||
|
||||
if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
return false, fmt.Errorf("syncSegmentID Failed:%s", resp.GetStatus().GetReason())
|
||||
}
|
||||
|
||||
var errMsg string
|
||||
now := time.Now()
|
||||
success := true
|
||||
for _, segAssign := range resp.SegIDAssignments {
|
||||
if segAssign.Status.GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
log.Ctx(context.TODO()).Warn("proxy", zap.String("SyncSegment Error", segAssign.GetStatus().GetReason()))
|
||||
errMsg += segAssign.GetStatus().GetReason()
|
||||
errMsg += "\n"
|
||||
success = false
|
||||
continue
|
||||
}
|
||||
assign, err := sa.getAssign(segAssign.CollectionID, segAssign.PartitionID, segAssign.ChannelName)
|
||||
segInfo2 := &segInfo{
|
||||
segID: segAssign.SegID,
|
||||
count: segAssign.Count,
|
||||
expireTime: segAssign.ExpireTime,
|
||||
}
|
||||
if err != nil {
|
||||
colInfos, ok := sa.assignInfos[segAssign.CollectionID]
|
||||
if !ok {
|
||||
colInfos = list.New()
|
||||
}
|
||||
segInfos := list.New()
|
||||
|
||||
segInfos.PushBack(segInfo2)
|
||||
assign = &assignInfo{
|
||||
collID: segAssign.CollectionID,
|
||||
partitionID: segAssign.PartitionID,
|
||||
channelName: segAssign.ChannelName,
|
||||
segInfos: segInfos,
|
||||
}
|
||||
colInfos.PushBack(assign)
|
||||
sa.assignInfos[segAssign.CollectionID] = colInfos
|
||||
} else {
|
||||
assign.segInfos.PushBack(segInfo2)
|
||||
}
|
||||
assign.lastInsertTime = now
|
||||
}
|
||||
if !success {
|
||||
return false, errors.New(errMsg)
|
||||
}
|
||||
return success, nil
|
||||
}
|
||||
|
||||
func (sa *segIDAssigner) processFunc(req allocator.Request) error {
|
||||
segRequest := req.(*segRequest)
|
||||
assign, err := sa.getAssign(segRequest.collID, segRequest.partitionID, segRequest.channelName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err2 := assign.Assign(segRequest.timestamp, segRequest.count)
|
||||
segRequest.segInfo = result
|
||||
return err2
|
||||
}
|
||||
|
||||
func (sa *segIDAssigner) GetSegmentID(collID UniqueID, partitionID UniqueID, channelName string, count uint32, ts Timestamp) (map[UniqueID]uint32, error) {
|
||||
req := &segRequest{
|
||||
BaseRequest: allocator.BaseRequest{Done: make(chan error), Valid: false},
|
||||
collID: collID,
|
||||
partitionID: partitionID,
|
||||
channelName: channelName,
|
||||
count: count,
|
||||
timestamp: ts,
|
||||
}
|
||||
sa.Reqs <- req
|
||||
if err := req.Wait(); err != nil {
|
||||
return nil, fmt.Errorf("getSegmentID failed: %s", err)
|
||||
}
|
||||
|
||||
return req.segInfo, nil
|
||||
}
|
||||
@ -1,307 +0,0 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
)
|
||||
|
||||
type mockDataCoord struct {
|
||||
expireTime Timestamp
|
||||
}
|
||||
|
||||
func (mockD *mockDataCoord) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) {
|
||||
assigns := make([]*datapb.SegmentIDAssignment, 0, len(req.SegmentIDRequests))
|
||||
maxPerCnt := 100
|
||||
for _, r := range req.SegmentIDRequests {
|
||||
totalCnt := uint32(0)
|
||||
for totalCnt != r.Count {
|
||||
cnt := uint32(rand.Intn(maxPerCnt))
|
||||
if totalCnt+cnt > r.Count {
|
||||
cnt = r.Count - totalCnt
|
||||
}
|
||||
totalCnt += cnt
|
||||
result := &datapb.SegmentIDAssignment{
|
||||
SegID: 1,
|
||||
ChannelName: r.ChannelName,
|
||||
Count: cnt,
|
||||
CollectionID: r.CollectionID,
|
||||
PartitionID: r.PartitionID,
|
||||
ExpireTime: mockD.expireTime,
|
||||
|
||||
Status: merr.Success(),
|
||||
}
|
||||
assigns = append(assigns, result)
|
||||
}
|
||||
}
|
||||
|
||||
return &datapb.AssignSegmentIDResponse{
|
||||
Status: merr.Success(),
|
||||
SegIDAssignments: assigns,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type mockDataCoord2 struct {
|
||||
expireTime Timestamp
|
||||
}
|
||||
|
||||
func (mockD *mockDataCoord2) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) {
|
||||
return &datapb.AssignSegmentIDResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "Just For Test",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getLastTick1() Timestamp {
|
||||
return 1000
|
||||
}
|
||||
|
||||
func TestSegmentAllocator1(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dataCoord := &mockDataCoord{}
|
||||
dataCoord.expireTime = Timestamp(1000)
|
||||
segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick1)
|
||||
assert.NoError(t, err)
|
||||
wg := &sync.WaitGroup{}
|
||||
segAllocator.Start()
|
||||
|
||||
wg.Add(1)
|
||||
go func(group *sync.WaitGroup) {
|
||||
defer group.Done()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
segAllocator.Close()
|
||||
}(wg)
|
||||
total := uint32(0)
|
||||
collNames := []string{"abc", "cba"}
|
||||
for i := 0; i < 10; i++ {
|
||||
colName := collNames[i%2]
|
||||
ret, err := segAllocator.GetSegmentID(1, 1, colName, 1, 1)
|
||||
assert.NoError(t, err)
|
||||
total += ret[1]
|
||||
}
|
||||
assert.Equal(t, uint32(10), total)
|
||||
|
||||
ret, err := segAllocator.GetSegmentID(1, 1, "abc", segCountPerRPC-10, 999)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, uint32(segCountPerRPC-10), ret[1])
|
||||
|
||||
_, err = segAllocator.GetSegmentID(1, 1, "abc", 10, 1001)
|
||||
assert.Error(t, err)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
var curLastTick2 = Timestamp(200)
|
||||
|
||||
var curLastTIck2Lock sync.Mutex
|
||||
|
||||
func getLastTick2() Timestamp {
|
||||
curLastTIck2Lock.Lock()
|
||||
defer curLastTIck2Lock.Unlock()
|
||||
curLastTick2 += 100
|
||||
return curLastTick2
|
||||
}
|
||||
|
||||
func TestSegmentAllocator2(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dataCoord := &mockDataCoord{}
|
||||
dataCoord.expireTime = Timestamp(500)
|
||||
segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick2)
|
||||
assert.NoError(t, err)
|
||||
segAllocator.Start()
|
||||
defer segAllocator.Close()
|
||||
|
||||
total := uint32(0)
|
||||
for i := 0; i < 10; i++ {
|
||||
ret, err := segAllocator.GetSegmentID(1, 1, "abc", 1, 200)
|
||||
assert.NoError(t, err)
|
||||
total += ret[1]
|
||||
}
|
||||
assert.Equal(t, uint32(10), total)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_, err = segAllocator.GetSegmentID(1, 1, "abc", segCountPerRPC-10, getLastTick2())
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSegmentAllocator3(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dataCoord := &mockDataCoord2{}
|
||||
dataCoord.expireTime = Timestamp(500)
|
||||
segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick2)
|
||||
assert.NoError(t, err)
|
||||
wg := &sync.WaitGroup{}
|
||||
segAllocator.Start()
|
||||
|
||||
wg.Add(1)
|
||||
go func(group *sync.WaitGroup) {
|
||||
defer group.Done()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
segAllocator.Close()
|
||||
}(wg)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_, err = segAllocator.GetSegmentID(1, 1, "abc", 10, 100)
|
||||
assert.Error(t, err)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
type mockDataCoord3 struct {
|
||||
expireTime Timestamp
|
||||
}
|
||||
|
||||
func (mockD *mockDataCoord3) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) {
|
||||
assigns := make([]*datapb.SegmentIDAssignment, 0, len(req.SegmentIDRequests))
|
||||
for i, r := range req.SegmentIDRequests {
|
||||
errCode := commonpb.ErrorCode_Success
|
||||
reason := ""
|
||||
if i == 0 {
|
||||
errCode = commonpb.ErrorCode_UnexpectedError
|
||||
reason = "Just for test"
|
||||
}
|
||||
result := &datapb.SegmentIDAssignment{
|
||||
SegID: 1,
|
||||
ChannelName: r.ChannelName,
|
||||
Count: r.Count,
|
||||
CollectionID: r.CollectionID,
|
||||
PartitionID: r.PartitionID,
|
||||
ExpireTime: mockD.expireTime,
|
||||
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: errCode,
|
||||
Reason: reason,
|
||||
},
|
||||
}
|
||||
assigns = append(assigns, result)
|
||||
}
|
||||
|
||||
return &datapb.AssignSegmentIDResponse{
|
||||
Status: merr.Success(),
|
||||
SegIDAssignments: assigns,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestSegmentAllocator4(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dataCoord := &mockDataCoord3{}
|
||||
dataCoord.expireTime = Timestamp(500)
|
||||
segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick2)
|
||||
assert.NoError(t, err)
|
||||
wg := &sync.WaitGroup{}
|
||||
segAllocator.Start()
|
||||
|
||||
wg.Add(1)
|
||||
go func(group *sync.WaitGroup) {
|
||||
defer group.Done()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
segAllocator.Close()
|
||||
}(wg)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_, err = segAllocator.GetSegmentID(1, 1, "abc", 10, 100)
|
||||
assert.Error(t, err)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
type mockDataCoord5 struct {
|
||||
expireTime Timestamp
|
||||
}
|
||||
|
||||
func (mockD *mockDataCoord5) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) {
|
||||
return &datapb.AssignSegmentIDResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "Just For Test",
|
||||
},
|
||||
}, errors.New("just for test")
|
||||
}
|
||||
|
||||
func TestSegmentAllocator5(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dataCoord := &mockDataCoord5{}
|
||||
dataCoord.expireTime = Timestamp(500)
|
||||
segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick2)
|
||||
assert.NoError(t, err)
|
||||
wg := &sync.WaitGroup{}
|
||||
segAllocator.Start()
|
||||
|
||||
wg.Add(1)
|
||||
go func(group *sync.WaitGroup) {
|
||||
defer group.Done()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
segAllocator.Close()
|
||||
}(wg)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_, err = segAllocator.GetSegmentID(1, 1, "abc", 10, 100)
|
||||
assert.Error(t, err)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestSegmentAllocator6(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
dataCoord := &mockDataCoord{}
|
||||
dataCoord.expireTime = Timestamp(500)
|
||||
segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick2)
|
||||
assert.NoError(t, err)
|
||||
wg := &sync.WaitGroup{}
|
||||
segAllocator.Start()
|
||||
|
||||
wg.Add(1)
|
||||
go func(group *sync.WaitGroup) {
|
||||
defer group.Done()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
segAllocator.Close()
|
||||
}(wg)
|
||||
success := true
|
||||
var sucLock sync.Mutex
|
||||
collNames := []string{"abc", "cba"}
|
||||
reqFunc := func(i int, group *sync.WaitGroup) {
|
||||
defer group.Done()
|
||||
sucLock.Lock()
|
||||
defer sucLock.Unlock()
|
||||
if !success {
|
||||
return
|
||||
}
|
||||
colName := collNames[i%2]
|
||||
count := uint32(10)
|
||||
if i == 0 {
|
||||
count = 0
|
||||
}
|
||||
_, err = segAllocator.GetSegmentID(1, 1, colName, count, 100)
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
success = false
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go reqFunc(i, wg)
|
||||
}
|
||||
wg.Wait()
|
||||
assert.True(t, success)
|
||||
}
|
||||
@ -587,7 +587,6 @@ type dropCollectionTask struct {
|
||||
mixCoord types.MixCoordClient
|
||||
result *commonpb.Status
|
||||
chMgr channelsMgr
|
||||
chTicker channelsTimeTicker
|
||||
}
|
||||
|
||||
func (t *dropCollectionTask) TraceCtx() context.Context {
|
||||
@ -1047,10 +1046,9 @@ type alterCollectionTask struct {
|
||||
baseTask
|
||||
Condition
|
||||
*milvuspb.AlterCollectionRequest
|
||||
ctx context.Context
|
||||
mixCoord types.MixCoordClient
|
||||
result *commonpb.Status
|
||||
replicateMsgStream msgstream.MsgStream
|
||||
ctx context.Context
|
||||
mixCoord types.MixCoordClient
|
||||
result *commonpb.Status
|
||||
}
|
||||
|
||||
func (t *alterCollectionTask) TraceCtx() context.Context {
|
||||
@ -1266,7 +1264,6 @@ func (t *alterCollectionTask) Execute(ctx context.Context) error {
|
||||
if err = merr.CheckRPCCall(t.result, err); err != nil {
|
||||
return err
|
||||
}
|
||||
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.AlterCollectionRequest)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -1278,10 +1275,9 @@ type alterCollectionFieldTask struct {
|
||||
baseTask
|
||||
Condition
|
||||
*milvuspb.AlterCollectionFieldRequest
|
||||
ctx context.Context
|
||||
mixCoord types.MixCoordClient
|
||||
result *commonpb.Status
|
||||
replicateMsgStream msgstream.MsgStream
|
||||
ctx context.Context
|
||||
mixCoord types.MixCoordClient
|
||||
result *commonpb.Status
|
||||
}
|
||||
|
||||
func (t *alterCollectionFieldTask) TraceCtx() context.Context {
|
||||
@ -1495,7 +1491,6 @@ func (t *alterCollectionFieldTask) Execute(ctx context.Context) error {
|
||||
if err = merr.CheckRPCCall(t.result, err); err != nil {
|
||||
return err
|
||||
}
|
||||
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.AlterCollectionFieldRequest)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -1938,8 +1933,7 @@ type loadCollectionTask struct {
|
||||
mixCoord types.MixCoordClient
|
||||
result *commonpb.Status
|
||||
|
||||
collectionID UniqueID
|
||||
replicateMsgStream msgstream.MsgStream
|
||||
collectionID UniqueID
|
||||
}
|
||||
|
||||
func (t *loadCollectionTask) TraceCtx() context.Context {
|
||||
@ -2088,7 +2082,6 @@ func (t *loadCollectionTask) Execute(ctx context.Context) (err error) {
|
||||
if err = merr.CheckRPCCall(t.result, err); err != nil {
|
||||
return fmt.Errorf("call query coordinator LoadCollection: %s", err)
|
||||
}
|
||||
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.LoadCollectionRequest)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -2111,8 +2104,7 @@ type releaseCollectionTask struct {
|
||||
mixCoord types.MixCoordClient
|
||||
result *commonpb.Status
|
||||
|
||||
collectionID UniqueID
|
||||
replicateMsgStream msgstream.MsgStream
|
||||
collectionID UniqueID
|
||||
}
|
||||
|
||||
func (t *releaseCollectionTask) TraceCtx() context.Context {
|
||||
@ -2186,7 +2178,6 @@ func (t *releaseCollectionTask) Execute(ctx context.Context) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.ReleaseCollectionRequest)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -2202,8 +2193,7 @@ type loadPartitionsTask struct {
|
||||
mixCoord types.MixCoordClient
|
||||
result *commonpb.Status
|
||||
|
||||
collectionID UniqueID
|
||||
replicateMsgStream msgstream.MsgStream
|
||||
collectionID UniqueID
|
||||
}
|
||||
|
||||
func (t *loadPartitionsTask) TraceCtx() context.Context {
|
||||
@ -2362,7 +2352,6 @@ func (t *loadPartitionsTask) Execute(ctx context.Context) error {
|
||||
if err = merr.CheckRPCCall(t.result, err); err != nil {
|
||||
return err
|
||||
}
|
||||
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.LoadPartitionsRequest)
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -2379,8 +2368,7 @@ type releasePartitionsTask struct {
|
||||
mixCoord types.MixCoordClient
|
||||
result *commonpb.Status
|
||||
|
||||
collectionID UniqueID
|
||||
replicateMsgStream msgstream.MsgStream
|
||||
collectionID UniqueID
|
||||
}
|
||||
|
||||
func (t *releasePartitionsTask) TraceCtx() context.Context {
|
||||
@ -2469,7 +2457,6 @@ func (t *releasePartitionsTask) Execute(ctx context.Context) (err error) {
|
||||
if err = merr.CheckRPCCall(t.result, err); err != nil {
|
||||
return err
|
||||
}
|
||||
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.ReleasePartitionsRequest)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@ -22,7 +22,6 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
@ -33,10 +32,9 @@ type CreateAliasTask struct {
|
||||
baseTask
|
||||
Condition
|
||||
*milvuspb.CreateAliasRequest
|
||||
ctx context.Context
|
||||
mixCoord types.MixCoordClient
|
||||
replicateMsgStream msgstream.MsgStream
|
||||
result *commonpb.Status
|
||||
ctx context.Context
|
||||
mixCoord types.MixCoordClient
|
||||
result *commonpb.Status
|
||||
}
|
||||
|
||||
// TraceCtx returns the trace context of the task.
|
||||
@ -111,7 +109,6 @@ func (t *CreateAliasTask) Execute(ctx context.Context) error {
|
||||
if err = merr.CheckRPCCall(t.result, err); err != nil {
|
||||
return err
|
||||
}
|
||||
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.CreateAliasRequest)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -125,10 +122,9 @@ type DropAliasTask struct {
|
||||
baseTask
|
||||
Condition
|
||||
*milvuspb.DropAliasRequest
|
||||
ctx context.Context
|
||||
mixCoord types.MixCoordClient
|
||||
replicateMsgStream msgstream.MsgStream
|
||||
result *commonpb.Status
|
||||
ctx context.Context
|
||||
mixCoord types.MixCoordClient
|
||||
result *commonpb.Status
|
||||
}
|
||||
|
||||
// TraceCtx returns the context for trace
|
||||
@ -189,7 +185,6 @@ func (t *DropAliasTask) Execute(ctx context.Context) error {
|
||||
if err = merr.CheckRPCCall(t.result, err); err != nil {
|
||||
return err
|
||||
}
|
||||
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.DropAliasRequest)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -202,10 +197,9 @@ type AlterAliasTask struct {
|
||||
baseTask
|
||||
Condition
|
||||
*milvuspb.AlterAliasRequest
|
||||
ctx context.Context
|
||||
mixCoord types.MixCoordClient
|
||||
replicateMsgStream msgstream.MsgStream
|
||||
result *commonpb.Status
|
||||
ctx context.Context
|
||||
mixCoord types.MixCoordClient
|
||||
result *commonpb.Status
|
||||
}
|
||||
|
||||
func (t *AlterAliasTask) TraceCtx() context.Context {
|
||||
@ -270,7 +264,6 @@ func (t *AlterAliasTask) Execute(ctx context.Context) error {
|
||||
if err = merr.CheckRPCCall(t.result, err); err != nil {
|
||||
return err
|
||||
}
|
||||
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.AlterAliasRequest)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@ -11,7 +11,6 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
@ -25,8 +24,6 @@ type createDatabaseTask struct {
|
||||
ctx context.Context
|
||||
mixCoord types.MixCoordClient
|
||||
result *commonpb.Status
|
||||
|
||||
replicateMsgStream msgstream.MsgStream
|
||||
}
|
||||
|
||||
func (cdt *createDatabaseTask) TraceCtx() context.Context {
|
||||
@ -78,9 +75,6 @@ func (cdt *createDatabaseTask) Execute(ctx context.Context) error {
|
||||
var err error
|
||||
cdt.result, err = cdt.mixCoord.CreateDatabase(ctx, cdt.CreateDatabaseRequest)
|
||||
err = merr.CheckRPCCall(cdt.result, err)
|
||||
if err == nil {
|
||||
SendReplicateMessagePack(ctx, cdt.replicateMsgStream, cdt.CreateDatabaseRequest)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@ -95,8 +89,6 @@ type dropDatabaseTask struct {
|
||||
ctx context.Context
|
||||
mixCoord types.MixCoordClient
|
||||
result *commonpb.Status
|
||||
|
||||
replicateMsgStream msgstream.MsgStream
|
||||
}
|
||||
|
||||
func (ddt *dropDatabaseTask) TraceCtx() context.Context {
|
||||
@ -151,7 +143,6 @@ func (ddt *dropDatabaseTask) Execute(ctx context.Context) error {
|
||||
err = merr.CheckRPCCall(ddt.result, err)
|
||||
if err == nil {
|
||||
globalMetaCache.RemoveDatabase(ctx, ddt.DbName)
|
||||
SendReplicateMessagePack(ctx, ddt.replicateMsgStream, ddt.DropDatabaseRequest)
|
||||
}
|
||||
return err
|
||||
}
|
||||
@ -230,8 +221,6 @@ type alterDatabaseTask struct {
|
||||
ctx context.Context
|
||||
mixCoord types.MixCoordClient
|
||||
result *commonpb.Status
|
||||
|
||||
replicateMsgStream msgstream.MsgStream
|
||||
}
|
||||
|
||||
func (t *alterDatabaseTask) TraceCtx() context.Context {
|
||||
@ -323,8 +312,6 @@ func (t *alterDatabaseTask) Execute(ctx context.Context) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.AlterDatabaseRequest)
|
||||
t.result = ret
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -2,13 +2,11 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/protobuf/proto"
|
||||
@ -21,7 +19,6 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/exprutil"
|
||||
"github.com/milvus-io/milvus/internal/util/streamingutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
||||
@ -133,60 +130,6 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dt *deleteTask) Execute(ctx context.Context) (err error) {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Delete-Execute")
|
||||
defer sp.End()
|
||||
|
||||
if len(dt.req.GetExpr()) == 0 {
|
||||
return merr.WrapErrParameterInvalid("valid expr", "empty expr", "invalid expression")
|
||||
}
|
||||
|
||||
dt.tr = timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute delete %d", dt.ID()))
|
||||
stream, err := dt.chMgr.getOrCreateDmlStream(ctx, dt.collectionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result, numRows, err := repackDeleteMsgByHash(
|
||||
ctx,
|
||||
dt.primaryKeys, dt.vChannels,
|
||||
dt.idAllocator, dt.ts,
|
||||
dt.collectionID, dt.req.GetCollectionName(),
|
||||
dt.partitionID, dt.req.GetPartitionName(),
|
||||
dt.req.GetDbName(),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// send delete request to log broker
|
||||
msgPack := &msgstream.MsgPack{
|
||||
BeginTs: dt.BeginTs(),
|
||||
EndTs: dt.EndTs(),
|
||||
}
|
||||
|
||||
for _, msgs := range result {
|
||||
for _, msg := range msgs {
|
||||
msgPack.Msgs = append(msgPack.Msgs, msg)
|
||||
}
|
||||
}
|
||||
|
||||
log.Ctx(ctx).Debug("send delete request to virtual channels",
|
||||
zap.String("collectionName", dt.req.GetCollectionName()),
|
||||
zap.Int64("collectionID", dt.collectionID),
|
||||
zap.Strings("virtual_channels", dt.vChannels),
|
||||
zap.Int64("taskID", dt.ID()),
|
||||
zap.Duration("prepare duration", dt.tr.RecordSpan()))
|
||||
|
||||
err = stream.Produce(ctx, msgPack)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dt.sessionTS = dt.ts
|
||||
dt.count += numRows
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dt *deleteTask) PostExecute(ctx context.Context) error {
|
||||
metrics.ProxyDeleteVectors.WithLabelValues(
|
||||
paramtable.GetStringNodeID(),
|
||||
@ -288,7 +231,6 @@ type deleteRunner struct {
|
||||
|
||||
// channel
|
||||
chMgr channelsMgr
|
||||
chTicker channelsTimeTicker
|
||||
vChannels []vChan
|
||||
|
||||
idAllocator allocator.Interface
|
||||
@ -437,20 +379,13 @@ func (dr *deleteRunner) produce(ctx context.Context, primaryKeys *schemapb.IDs,
|
||||
req: dr.req,
|
||||
idAllocator: dr.idAllocator,
|
||||
chMgr: dr.chMgr,
|
||||
chTicker: dr.chTicker,
|
||||
collectionID: dr.collectionID,
|
||||
partitionID: partitionID,
|
||||
vChannels: dr.vChannels,
|
||||
primaryKeys: primaryKeys,
|
||||
dbID: dr.dbID,
|
||||
}
|
||||
|
||||
var enqueuedTask task = dt
|
||||
if streamingutil.IsStreamingServiceEnabled() {
|
||||
enqueuedTask = &deleteTaskByStreamingService{deleteTask: dt}
|
||||
}
|
||||
|
||||
if err := dr.queue.Enqueue(enqueuedTask); err != nil {
|
||||
if err := dr.queue.Enqueue(dt); err != nil {
|
||||
log.Ctx(ctx).Error("Failed to enqueue delete task: " + err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -15,13 +15,9 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
type deleteTaskByStreamingService struct {
|
||||
*deleteTask
|
||||
}
|
||||
|
||||
// Execute is a function to delete task by streaming service
|
||||
// we only overwrite the Execute function
|
||||
func (dt *deleteTaskByStreamingService) Execute(ctx context.Context) (err error) {
|
||||
func (dt *deleteTask) Execute(ctx context.Context) (err error) {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Delete-Execute")
|
||||
defer sp.End()
|
||||
|
||||
|
||||
@ -14,11 +14,11 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
"github.com/milvus-io/milvus/internal/distributed/streaming"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
||||
"github.com/milvus-io/milvus/internal/util/streamrpc"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
|
||||
@ -152,19 +152,6 @@ func TestDeleteTask_Execute(t *testing.T) {
|
||||
assert.Error(t, dt.Execute(context.Background()))
|
||||
})
|
||||
|
||||
t.Run("get channel failed", func(t *testing.T) {
|
||||
mockMgr := NewMockChannelsMgr(t)
|
||||
dt := deleteTask{
|
||||
chMgr: mockMgr,
|
||||
req: &milvuspb.DeleteRequest{
|
||||
Expr: "pk in [1,2]",
|
||||
},
|
||||
}
|
||||
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
|
||||
assert.Error(t, dt.Execute(context.Background()))
|
||||
})
|
||||
|
||||
t.Run("alloc failed", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
@ -189,9 +176,6 @@ func TestDeleteTask_Execute(t *testing.T) {
|
||||
},
|
||||
primaryKeys: pk,
|
||||
}
|
||||
stream := msgstream.NewMockMsgStream(t)
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
|
||||
|
||||
assert.Error(t, dt.Execute(context.Background()))
|
||||
})
|
||||
|
||||
@ -225,9 +209,7 @@ func TestDeleteTask_Execute(t *testing.T) {
|
||||
},
|
||||
primaryKeys: pk,
|
||||
}
|
||||
stream := msgstream.NewMockMsgStream(t)
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
|
||||
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("mock error"))
|
||||
streaming.ExpectErrorOnce(errors.New("mock error"))
|
||||
assert.Error(t, dt.Execute(context.Background()))
|
||||
})
|
||||
}
|
||||
@ -666,7 +648,7 @@ func TestDeleteRunner_Run(t *testing.T) {
|
||||
tsoAllocator := &mockTsoAllocator{}
|
||||
idAllocator := &mockIDAllocatorInterface{}
|
||||
|
||||
queue, err := newTaskScheduler(ctx, tsoAllocator, nil)
|
||||
queue, err := newTaskScheduler(ctx, tsoAllocator)
|
||||
assert.NoError(t, err)
|
||||
queue.Start()
|
||||
defer queue.Close()
|
||||
@ -731,11 +713,8 @@ func TestDeleteRunner_Run(t *testing.T) {
|
||||
},
|
||||
plan: plan,
|
||||
}
|
||||
stream := msgstream.NewMockMsgStream(t)
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
|
||||
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
|
||||
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("mock error"))
|
||||
|
||||
streaming.ExpectErrorOnce(errors.New("mock error"))
|
||||
assert.Error(t, dr.Run(context.Background()))
|
||||
assert.Equal(t, int64(0), dr.result.DeleteCnt)
|
||||
})
|
||||
@ -814,10 +793,7 @@ func TestDeleteRunner_Run(t *testing.T) {
|
||||
},
|
||||
plan: plan,
|
||||
}
|
||||
stream := msgstream.NewMockMsgStream(t)
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
|
||||
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
|
||||
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
|
||||
return workload.exec(ctx, 1, qn, "")
|
||||
@ -946,8 +922,6 @@ func TestDeleteRunner_Run(t *testing.T) {
|
||||
},
|
||||
plan: plan,
|
||||
}
|
||||
stream := msgstream.NewMockMsgStream(t)
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
|
||||
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
|
||||
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
|
||||
return workload.exec(ctx, 1, qn, "")
|
||||
@ -971,8 +945,8 @@ func TestDeleteRunner_Run(t *testing.T) {
|
||||
server.FinishSend(nil)
|
||||
return client
|
||||
}, nil)
|
||||
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("mock error"))
|
||||
|
||||
streaming.ExpectErrorOnce(errors.New("mock error"))
|
||||
assert.Error(t, dr.Run(ctx))
|
||||
assert.Equal(t, int64(0), dr.result.DeleteCnt)
|
||||
})
|
||||
@ -1012,8 +986,6 @@ func TestDeleteRunner_Run(t *testing.T) {
|
||||
},
|
||||
plan: plan,
|
||||
}
|
||||
stream := msgstream.NewMockMsgStream(t)
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
|
||||
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
|
||||
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
|
||||
return workload.exec(ctx, 1, qn, "")
|
||||
@ -1037,8 +1009,6 @@ func TestDeleteRunner_Run(t *testing.T) {
|
||||
server.FinishSend(nil)
|
||||
return client
|
||||
}, nil)
|
||||
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
assert.NoError(t, dr.Run(ctx))
|
||||
assert.Equal(t, int64(3), dr.result.DeleteCnt)
|
||||
})
|
||||
@ -1088,8 +1058,6 @@ func TestDeleteRunner_Run(t *testing.T) {
|
||||
},
|
||||
plan: plan,
|
||||
}
|
||||
stream := msgstream.NewMockMsgStream(t)
|
||||
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
|
||||
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
|
||||
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
|
||||
return workload.exec(ctx, 1, qn, "")
|
||||
@ -1114,7 +1082,6 @@ func TestDeleteRunner_Run(t *testing.T) {
|
||||
return client
|
||||
}, nil)
|
||||
|
||||
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
|
||||
assert.NoError(t, dr.Run(ctx))
|
||||
assert.Equal(t, int64(3), dr.result.DeleteCnt)
|
||||
})
|
||||
|
||||
@ -18,17 +18,11 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
)
|
||||
|
||||
@ -40,7 +34,7 @@ type flushTask struct {
|
||||
mixCoord types.MixCoordClient
|
||||
result *milvuspb.FlushResponse
|
||||
|
||||
replicateMsgStream msgstream.MsgStream
|
||||
chMgr channelsMgr
|
||||
}
|
||||
|
||||
func (t *flushTask) TraceCtx() context.Context {
|
||||
@ -88,46 +82,6 @@ func (t *flushTask) PreExecute(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *flushTask) Execute(ctx context.Context) error {
|
||||
coll2Segments := make(map[string]*schemapb.LongArray)
|
||||
flushColl2Segments := make(map[string]*schemapb.LongArray)
|
||||
coll2SealTimes := make(map[string]int64)
|
||||
coll2FlushTs := make(map[string]Timestamp)
|
||||
channelCps := make(map[string]*msgpb.MsgPosition)
|
||||
for _, collName := range t.CollectionNames {
|
||||
collID, err := globalMetaCache.GetCollectionID(ctx, t.GetDbName(), collName)
|
||||
if err != nil {
|
||||
return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound)
|
||||
}
|
||||
flushReq := &datapb.FlushRequest{
|
||||
Base: commonpbutil.UpdateMsgBase(
|
||||
t.Base,
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Flush),
|
||||
),
|
||||
CollectionID: collID,
|
||||
}
|
||||
resp, err := t.mixCoord.Flush(ctx, flushReq)
|
||||
if err = merr.CheckRPCCall(resp, err); err != nil {
|
||||
return fmt.Errorf("failed to call flush to data coordinator: %s", err.Error())
|
||||
}
|
||||
coll2Segments[collName] = &schemapb.LongArray{Data: resp.GetSegmentIDs()}
|
||||
flushColl2Segments[collName] = &schemapb.LongArray{Data: resp.GetFlushSegmentIDs()}
|
||||
coll2SealTimes[collName] = resp.GetTimeOfSeal()
|
||||
coll2FlushTs[collName] = resp.GetFlushTs()
|
||||
channelCps = resp.GetChannelCps()
|
||||
}
|
||||
t.result = &milvuspb.FlushResponse{
|
||||
Status: merr.Success(),
|
||||
DbName: t.GetDbName(),
|
||||
CollSegIDs: coll2Segments,
|
||||
FlushCollSegIDs: flushColl2Segments,
|
||||
CollSealTimes: coll2SealTimes,
|
||||
CollFlushTs: coll2FlushTs,
|
||||
ChannelCps: channelCps,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *flushTask) PostExecute(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -18,15 +18,12 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
)
|
||||
|
||||
@ -37,8 +34,7 @@ type flushAllTask struct {
|
||||
ctx context.Context
|
||||
mixCoord types.MixCoordClient
|
||||
result *datapb.FlushAllResponse
|
||||
|
||||
replicateMsgStream msgstream.MsgStream
|
||||
chMgr channelsMgr
|
||||
}
|
||||
|
||||
func (t *flushAllTask) TraceCtx() context.Context {
|
||||
@ -86,22 +82,6 @@ func (t *flushAllTask) PreExecute(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *flushAllTask) Execute(ctx context.Context) error {
|
||||
flushAllReq := &datapb.FlushAllRequest{
|
||||
Base: commonpbutil.UpdateMsgBase(
|
||||
t.Base,
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Flush),
|
||||
),
|
||||
}
|
||||
|
||||
resp, err := t.mixCoord.FlushAll(ctx, flushAllReq)
|
||||
if err = merr.CheckRPCCall(resp, err); err != nil {
|
||||
return fmt.Errorf("failed to call flush all to data coordinator: %s", err.Error())
|
||||
}
|
||||
t.result = resp
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *flushAllTask) PostExecute(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -30,12 +30,7 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
)
|
||||
|
||||
type flushAllTaskbyStreamingService struct {
|
||||
*flushAllTask
|
||||
chMgr channelsMgr
|
||||
}
|
||||
|
||||
func (t *flushAllTaskbyStreamingService) Execute(ctx context.Context) error {
|
||||
func (t *flushAllTask) Execute(ctx context.Context) error {
|
||||
dbNames := make([]string, 0)
|
||||
if t.GetDbName() != "" {
|
||||
dbNames = append(dbNames, t.GetDbName())
|
||||
@ -60,7 +55,7 @@ func (t *flushAllTaskbyStreamingService) Execute(ctx context.Context) error {
|
||||
Base: commonpbutil.NewMsgBase(commonpbutil.WithMsgType(commonpb.MsgType_ShowCollections)),
|
||||
DbName: dbName,
|
||||
})
|
||||
if err != nil {
|
||||
if err := merr.CheckRPCCall(showColRsp, err); err != nil {
|
||||
log.Info("flush all task by streaming service failed, show collections failed", zap.String("dbName", dbName), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
@ -33,7 +33,7 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
)
|
||||
|
||||
func createTestFlushAllTaskByStreamingService(t *testing.T, dbName string) (*flushAllTaskbyStreamingService, *mocks.MockMixCoordClient, *msgstream.MockMsgStream, *MockChannelsMgr, context.Context) {
|
||||
func createTestFlushAllTaskByStreamingService(t *testing.T, dbName string) (*flushAllTask, *mocks.MockMixCoordClient, *msgstream.MockMsgStream, *MockChannelsMgr, context.Context) {
|
||||
ctx := context.Background()
|
||||
mixCoord := mocks.NewMockMixCoordClient(t)
|
||||
replicateMsgStream := msgstream.NewMockMsgStream(t)
|
||||
@ -51,17 +51,11 @@ func createTestFlushAllTaskByStreamingService(t *testing.T, dbName string) (*flu
|
||||
},
|
||||
DbName: dbName,
|
||||
},
|
||||
ctx: ctx,
|
||||
mixCoord: mixCoord,
|
||||
replicateMsgStream: replicateMsgStream,
|
||||
ctx: ctx,
|
||||
mixCoord: mixCoord,
|
||||
chMgr: chMgr,
|
||||
}
|
||||
|
||||
task := &flushAllTaskbyStreamingService{
|
||||
flushAllTask: baseTask,
|
||||
chMgr: chMgr,
|
||||
}
|
||||
|
||||
return task, mixCoord, replicateMsgStream, chMgr, ctx
|
||||
return baseTask, mixCoord, replicateMsgStream, chMgr, ctx
|
||||
}
|
||||
|
||||
func TestFlushAllTask_WithSpecificDB(t *testing.T) {
|
||||
|
||||
@ -18,19 +18,15 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/uniquegenerator"
|
||||
)
|
||||
|
||||
@ -50,9 +46,8 @@ func createTestFlushAllTask(t *testing.T) (*flushAllTask, *mocks.MockMixCoordCli
|
||||
SourceID: 1,
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
mixCoord: mixCoord,
|
||||
replicateMsgStream: replicateMsgStream,
|
||||
ctx: ctx,
|
||||
mixCoord: mixCoord,
|
||||
}
|
||||
|
||||
return task, mixCoord, replicateMsgStream, ctx
|
||||
@ -151,95 +146,6 @@ func TestFlushAllTaskPreExecute(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestFlushAllTaskExecuteSuccess(t *testing.T) {
|
||||
task, mixCoord, replicateMsgStream, ctx := createTestFlushAllTask(t)
|
||||
defer mixCoord.AssertExpectations(t)
|
||||
defer replicateMsgStream.AssertExpectations(t)
|
||||
|
||||
// Setup expectations
|
||||
expectedResp := &datapb.FlushAllResponse{
|
||||
Status: merr.Success(),
|
||||
}
|
||||
|
||||
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
|
||||
Return(expectedResp, nil).Once()
|
||||
|
||||
err := task.Execute(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedResp, task.result)
|
||||
}
|
||||
|
||||
func TestFlushAllTaskExecuteFlushAllRPCError(t *testing.T) {
|
||||
task, mixCoord, replicateMsgStream, ctx := createTestFlushAllTask(t)
|
||||
defer mixCoord.AssertExpectations(t)
|
||||
defer replicateMsgStream.AssertExpectations(t)
|
||||
|
||||
// Test RPC call error
|
||||
expectedErr := fmt.Errorf("rpc error")
|
||||
|
||||
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
|
||||
Return(nil, expectedErr).Once()
|
||||
|
||||
err := task.Execute(ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to call flush all to data coordinator")
|
||||
}
|
||||
|
||||
func TestFlushAllTaskExecuteFlushAllResponseError(t *testing.T) {
|
||||
task, mixCoord, replicateMsgStream, ctx := createTestFlushAllTask(t)
|
||||
defer mixCoord.AssertExpectations(t)
|
||||
defer replicateMsgStream.AssertExpectations(t)
|
||||
|
||||
// Test response with error status
|
||||
errorResp := &datapb.FlushAllResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "flush all failed",
|
||||
},
|
||||
}
|
||||
|
||||
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
|
||||
Return(errorResp, nil).Once()
|
||||
|
||||
err := task.Execute(ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to call flush all to data coordinator")
|
||||
}
|
||||
|
||||
func TestFlushAllTaskExecuteWithMerCheck(t *testing.T) {
|
||||
task, mixCoord, replicateMsgStream, ctx := createTestFlushAllTask(t)
|
||||
defer mixCoord.AssertExpectations(t)
|
||||
defer replicateMsgStream.AssertExpectations(t)
|
||||
|
||||
// Test successful execution with merr.CheckRPCCall
|
||||
successResp := &datapb.FlushAllResponse{
|
||||
Status: merr.Success(),
|
||||
}
|
||||
|
||||
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
|
||||
Return(successResp, nil).Once()
|
||||
|
||||
err := task.Execute(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, successResp, task.result)
|
||||
}
|
||||
|
||||
func TestFlushAllTaskExecuteRequestContent(t *testing.T) {
|
||||
task, mixCoord, replicateMsgStream, ctx := createTestFlushAllTask(t)
|
||||
defer mixCoord.AssertExpectations(t)
|
||||
defer replicateMsgStream.AssertExpectations(t)
|
||||
|
||||
// Test the content of the FlushAllRequest sent to mixCoord
|
||||
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
|
||||
Return(&datapb.FlushAllResponse{Status: merr.Success()}, nil).Once()
|
||||
|
||||
err := task.Execute(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// The test verifies that Execute method creates the correct request structure internally
|
||||
// The actual request content validation is covered by other tests
|
||||
}
|
||||
|
||||
func TestFlushAllTaskPostExecute(t *testing.T) {
|
||||
task, mixCoord, replicateMsgStream, ctx := createTestFlushAllTask(t)
|
||||
defer mixCoord.AssertExpectations(t)
|
||||
@ -249,97 +155,6 @@ func TestFlushAllTaskPostExecute(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestFlushAllTaskLifecycle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mixCoord := mocks.NewMockMixCoordClient(t)
|
||||
replicateMsgStream := msgstream.NewMockMsgStream(t)
|
||||
defer mixCoord.AssertExpectations(t)
|
||||
defer replicateMsgStream.AssertExpectations(t)
|
||||
|
||||
// Test complete task lifecycle
|
||||
|
||||
// 1. OnEnqueue
|
||||
task := &flushAllTask{
|
||||
baseTask: baseTask{},
|
||||
Condition: NewTaskCondition(ctx),
|
||||
FlushAllRequest: &milvuspb.FlushAllRequest{},
|
||||
ctx: ctx,
|
||||
mixCoord: mixCoord,
|
||||
replicateMsgStream: replicateMsgStream,
|
||||
}
|
||||
|
||||
err := task.OnEnqueue()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 2. PreExecute
|
||||
err = task.PreExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 3. Execute
|
||||
expectedResp := &datapb.FlushAllResponse{
|
||||
Status: merr.Success(),
|
||||
}
|
||||
|
||||
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
|
||||
Return(expectedResp, nil).Once()
|
||||
|
||||
err = task.Execute(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 4. PostExecute
|
||||
err = task.PostExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify task state
|
||||
assert.Equal(t, expectedResp, task.result)
|
||||
}
|
||||
|
||||
func TestFlushAllTaskErrorHandlingInExecute(t *testing.T) {
|
||||
// Test different error scenarios in Execute method
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupMock func(*mocks.MockMixCoordClient)
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "mixCoord FlushAll returns error",
|
||||
setupMock: func(mixCoord *mocks.MockMixCoordClient) {
|
||||
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
|
||||
Return(nil, fmt.Errorf("network error")).Once()
|
||||
},
|
||||
expectedError: "failed to call flush all to data coordinator",
|
||||
},
|
||||
{
|
||||
name: "mixCoord FlushAll returns error status",
|
||||
setupMock: func(mixCoord *mocks.MockMixCoordClient) {
|
||||
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
|
||||
Return(&datapb.FlushAllResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_IllegalArgument,
|
||||
Reason: "invalid request",
|
||||
},
|
||||
}, nil).Once()
|
||||
},
|
||||
expectedError: "failed to call flush all to data coordinator",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
task, mixCoord, replicateMsgStream, ctx := createTestFlushAllTask(t)
|
||||
defer mixCoord.AssertExpectations(t)
|
||||
defer replicateMsgStream.AssertExpectations(t)
|
||||
|
||||
tc.setupMock(mixCoord)
|
||||
|
||||
err := task.Execute(ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tc.expectedError)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlushAllTaskImplementsTaskInterface(t *testing.T) {
|
||||
// Verify that flushAllTask implements the task interface
|
||||
var _ task = (*flushAllTask)(nil)
|
||||
|
||||
@ -35,12 +35,7 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
|
||||
)
|
||||
|
||||
type flushTaskByStreamingService struct {
|
||||
*flushTask
|
||||
chMgr channelsMgr
|
||||
}
|
||||
|
||||
func (t *flushTaskByStreamingService) Execute(ctx context.Context) error {
|
||||
func (t *flushTask) Execute(ctx context.Context) error {
|
||||
coll2Segments := make(map[string]*schemapb.LongArray)
|
||||
flushColl2Segments := make(map[string]*schemapb.LongArray)
|
||||
coll2SealTimes := make(map[string]int64)
|
||||
|
||||
@ -32,7 +32,6 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
||||
@ -64,8 +63,6 @@ type createIndexTask struct {
|
||||
mixCoord types.MixCoordClient
|
||||
result *commonpb.Status
|
||||
|
||||
replicateMsgStream msgstream.MsgStream
|
||||
|
||||
isAutoIndex bool
|
||||
newIndexParams []*commonpb.KeyValuePair
|
||||
newTypeParams []*commonpb.KeyValuePair
|
||||
@ -580,7 +577,6 @@ func (cit *createIndexTask) Execute(ctx context.Context) error {
|
||||
if err = merr.CheckRPCCall(cit.result, err); err != nil {
|
||||
return err
|
||||
}
|
||||
SendReplicateMessagePack(ctx, cit.replicateMsgStream, cit.req)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -596,8 +592,6 @@ type alterIndexTask struct {
|
||||
mixCoord types.MixCoordClient
|
||||
result *commonpb.Status
|
||||
|
||||
replicateMsgStream msgstream.MsgStream
|
||||
|
||||
collectionID UniqueID
|
||||
}
|
||||
|
||||
@ -711,7 +705,6 @@ func (t *alterIndexTask) Execute(ctx context.Context) error {
|
||||
if err = merr.CheckRPCCall(t.result, err); err != nil {
|
||||
return err
|
||||
}
|
||||
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.req)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -978,8 +971,6 @@ type dropIndexTask struct {
|
||||
result *commonpb.Status
|
||||
|
||||
collectionID UniqueID
|
||||
|
||||
replicateMsgStream msgstream.MsgStream
|
||||
}
|
||||
|
||||
func (dit *dropIndexTask) TraceCtx() context.Context {
|
||||
@ -1073,7 +1064,6 @@ func (dit *dropIndexTask) Execute(ctx context.Context) error {
|
||||
ctxLog.Warn("drop index failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
SendReplicateMessagePack(ctx, dit.replicateMsgStream, dit.DropIndexRequest)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@ -31,6 +31,7 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/distributed/streaming"
|
||||
"github.com/milvus-io/milvus/internal/json"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
||||
@ -46,6 +47,7 @@ import (
|
||||
func TestMain(m *testing.M) {
|
||||
paramtable.Init()
|
||||
gin.SetMode(gin.TestMode)
|
||||
streaming.SetupNoopWALForTest()
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
@ -2,7 +2,6 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
@ -15,7 +14,6 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/util/function"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
@ -32,9 +30,7 @@ type insertTask struct {
|
||||
|
||||
result *milvuspb.MutationResult
|
||||
idAllocator *allocator.IDAllocator
|
||||
segIDAssigner *segIDAssigner
|
||||
chMgr channelsMgr
|
||||
chTicker channelsTimeTicker
|
||||
vChannels []vChan
|
||||
pChannels []pChan
|
||||
schema *schemapb.CollectionSchema
|
||||
@ -292,76 +288,6 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *insertTask) Execute(ctx context.Context) error {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Insert-Execute")
|
||||
defer sp.End()
|
||||
|
||||
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute insert %d", it.ID()))
|
||||
|
||||
collectionName := it.insertMsg.CollectionName
|
||||
collID, err := globalMetaCache.GetCollectionID(it.ctx, it.insertMsg.GetDbName(), collectionName)
|
||||
log := log.Ctx(ctx)
|
||||
if err != nil {
|
||||
log.Warn("fail to get collection id", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
it.insertMsg.CollectionID = collID
|
||||
|
||||
getCacheDur := tr.RecordSpan()
|
||||
stream, err := it.chMgr.getOrCreateDmlStream(ctx, collID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
getMsgStreamDur := tr.RecordSpan()
|
||||
|
||||
channelNames, err := it.chMgr.getVChannels(collID)
|
||||
if err != nil {
|
||||
log.Warn("get vChannels failed", zap.Int64("collectionID", collID), zap.Error(err))
|
||||
it.result.Status = merr.Status(err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("send insert request to virtual channels",
|
||||
zap.String("partition", it.insertMsg.GetPartitionName()),
|
||||
zap.Int64("collectionID", collID),
|
||||
zap.Strings("virtual_channels", channelNames),
|
||||
zap.Int64("task_id", it.ID()),
|
||||
zap.Duration("get cache duration", getCacheDur),
|
||||
zap.Duration("get msgStream duration", getMsgStreamDur))
|
||||
|
||||
// assign segmentID for insert data and repack data by segmentID
|
||||
var msgPack *msgstream.MsgPack
|
||||
if it.partitionKeys == nil {
|
||||
msgPack, err = repackInsertData(it.TraceCtx(), channelNames, it.insertMsg, it.result, it.idAllocator, it.segIDAssigner)
|
||||
} else {
|
||||
msgPack, err = repackInsertDataWithPartitionKey(it.TraceCtx(), channelNames, it.partitionKeys, it.insertMsg, it.result, it.idAllocator, it.segIDAssigner)
|
||||
}
|
||||
if err != nil {
|
||||
log.Warn("assign segmentID and repack insert data failed", zap.Error(err))
|
||||
it.result.Status = merr.Status(err)
|
||||
return err
|
||||
}
|
||||
assignSegmentIDDur := tr.RecordSpan()
|
||||
|
||||
log.Debug("assign segmentID for insert data success",
|
||||
zap.Duration("assign segmentID duration", assignSegmentIDDur))
|
||||
err = stream.Produce(ctx, msgPack)
|
||||
if err != nil {
|
||||
log.Warn("fail to produce insert msg", zap.Error(err))
|
||||
it.result.Status = merr.Status(err)
|
||||
return err
|
||||
}
|
||||
sendMsgDur := tr.RecordSpan()
|
||||
metrics.ProxySendMutationReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.InsertLabel).Observe(float64(sendMsgDur.Milliseconds()))
|
||||
totalExecDur := tr.ElapseSpan()
|
||||
log.Debug("Proxy Insert Execute done",
|
||||
zap.String("collectionName", collectionName),
|
||||
zap.Duration("send message duration", sendMsgDur),
|
||||
zap.Duration("execute duration", totalExecDur))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *insertTask) PostExecute(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -18,12 +18,8 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
type insertTaskByStreamingService struct {
|
||||
*insertTask
|
||||
}
|
||||
|
||||
// we only overwrite the Execute function
|
||||
func (it *insertTaskByStreamingService) Execute(ctx context.Context) error {
|
||||
func (it *insertTask) Execute(ctx context.Context) error {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Insert-Execute")
|
||||
defer sp.End()
|
||||
|
||||
|
||||
@ -29,7 +29,6 @@ import (
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/conc"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/metricsinfo"
|
||||
@ -437,22 +436,18 @@ type taskScheduler struct {
|
||||
wg sync.WaitGroup
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
msFactory msgstream.Factory
|
||||
}
|
||||
|
||||
type schedOpt func(*taskScheduler)
|
||||
|
||||
func newTaskScheduler(ctx context.Context,
|
||||
tsoAllocatorIns tsoAllocator,
|
||||
factory msgstream.Factory,
|
||||
opts ...schedOpt,
|
||||
) (*taskScheduler, error) {
|
||||
ctx1, cancel := context.WithCancel(ctx)
|
||||
s := &taskScheduler{
|
||||
ctx: ctx1,
|
||||
cancel: cancel,
|
||||
msFactory: factory,
|
||||
ctx: ctx1,
|
||||
cancel: cancel,
|
||||
}
|
||||
s.ddQueue = newDdTaskQueue(tsoAllocatorIns)
|
||||
s.dmQueue = newDmTaskQueue(tsoAllocatorIns)
|
||||
|
||||
@ -492,9 +492,8 @@ func TestTaskScheduler(t *testing.T) {
|
||||
|
||||
ctx := context.Background()
|
||||
tsoAllocatorIns := newMockTsoAllocator()
|
||||
factory := newSimpleMockMsgStreamFactory()
|
||||
|
||||
sched, err := newTaskScheduler(ctx, tsoAllocatorIns, factory)
|
||||
sched, err := newTaskScheduler(ctx, tsoAllocatorIns)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, sched)
|
||||
|
||||
@ -572,8 +571,7 @@ func TestTaskScheduler_concurrentPushAndPop(t *testing.T) {
|
||||
).Return(collectionID, nil)
|
||||
globalMetaCache = cache
|
||||
tsoAllocatorIns := newMockTsoAllocator()
|
||||
factory := newSimpleMockMsgStreamFactory()
|
||||
scheduler, err := newTaskScheduler(context.Background(), tsoAllocatorIns, factory)
|
||||
scheduler, err := newTaskScheduler(context.Background(), tsoAllocatorIns)
|
||||
assert.NoError(t, err)
|
||||
|
||||
run := func(wg *sync.WaitGroup) {
|
||||
|
||||
@ -3586,7 +3586,7 @@ func TestSearchTask_Requery(t *testing.T) {
|
||||
node.tsoAllocator = ×tampAllocator{
|
||||
tso: newMockTimestampAllocatorInterface(),
|
||||
}
|
||||
scheduler, err := newTaskScheduler(ctx, node.tsoAllocator, factory)
|
||||
scheduler, err := newTaskScheduler(ctx, node.tsoAllocator)
|
||||
assert.NoError(t, err)
|
||||
node.sched = scheduler
|
||||
err = node.sched.Start()
|
||||
|
||||
@ -2161,12 +2161,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
dmlChannelsFunc := getDmlChannelsFunc(ctx, qc)
|
||||
factory := newSimpleMockMsgStreamFactory()
|
||||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
|
||||
defer chMgr.removeAllDMLStream()
|
||||
|
||||
_, err = chMgr.getOrCreateDmlStream(ctx, collectionID)
|
||||
assert.NoError(t, err)
|
||||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil)
|
||||
pchans, err := chMgr.getChannels(collectionID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@ -2182,11 +2177,6 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
|
||||
_ = idAllocator.Start()
|
||||
defer idAllocator.Close()
|
||||
|
||||
segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
|
||||
assert.NoError(t, err)
|
||||
_ = segAllocator.Start()
|
||||
defer segAllocator.Close()
|
||||
|
||||
t.Run("insert", func(t *testing.T) {
|
||||
hash := testutils.GenerateHashKeys(nb)
|
||||
task := &insertTask{
|
||||
@ -2221,13 +2211,11 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
|
||||
UpsertCnt: 0,
|
||||
Timestamp: 0,
|
||||
},
|
||||
idAllocator: idAllocator,
|
||||
segIDAssigner: segAllocator,
|
||||
chMgr: chMgr,
|
||||
chTicker: ticker,
|
||||
vChannels: nil,
|
||||
pChannels: nil,
|
||||
schema: nil,
|
||||
idAllocator: idAllocator,
|
||||
chMgr: chMgr,
|
||||
vChannels: nil,
|
||||
pChannels: nil,
|
||||
schema: nil,
|
||||
}
|
||||
|
||||
for fieldName, dataType := range fieldName2Types {
|
||||
@ -2403,12 +2391,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
dmlChannelsFunc := getDmlChannelsFunc(ctx, mixc)
|
||||
factory := newSimpleMockMsgStreamFactory()
|
||||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
|
||||
defer chMgr.removeAllDMLStream()
|
||||
|
||||
_, err = chMgr.getOrCreateDmlStream(ctx, collectionID)
|
||||
assert.NoError(t, err)
|
||||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil)
|
||||
pchans, err := chMgr.getChannels(collectionID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@ -2424,12 +2407,6 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
|
||||
_ = idAllocator.Start()
|
||||
defer idAllocator.Close()
|
||||
|
||||
segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
|
||||
assert.NoError(t, err)
|
||||
segAllocator.Init()
|
||||
_ = segAllocator.Start()
|
||||
defer segAllocator.Close()
|
||||
|
||||
t.Run("insert", func(t *testing.T) {
|
||||
hash := testutils.GenerateHashKeys(nb)
|
||||
task := &insertTask{
|
||||
@ -2464,13 +2441,11 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
|
||||
UpsertCnt: 0,
|
||||
Timestamp: 0,
|
||||
},
|
||||
idAllocator: idAllocator,
|
||||
segIDAssigner: segAllocator,
|
||||
chMgr: chMgr,
|
||||
chTicker: ticker,
|
||||
vChannels: nil,
|
||||
pChannels: nil,
|
||||
schema: nil,
|
||||
idAllocator: idAllocator,
|
||||
chMgr: chMgr,
|
||||
vChannels: nil,
|
||||
pChannels: nil,
|
||||
schema: nil,
|
||||
}
|
||||
|
||||
fieldID := common.StartOfUserFieldID
|
||||
@ -2549,13 +2524,12 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
|
||||
UpsertCnt: 0,
|
||||
Timestamp: 0,
|
||||
},
|
||||
idAllocator: idAllocator,
|
||||
segIDAssigner: segAllocator,
|
||||
chMgr: chMgr,
|
||||
chTicker: ticker,
|
||||
vChannels: nil,
|
||||
pChannels: nil,
|
||||
schema: nil,
|
||||
idAllocator: idAllocator,
|
||||
chMgr: chMgr,
|
||||
chTicker: ticker,
|
||||
vChannels: nil,
|
||||
pChannels: nil,
|
||||
schema: nil,
|
||||
}
|
||||
|
||||
fieldID := common.StartOfUserFieldID
|
||||
@ -3817,12 +3791,7 @@ func TestPartitionKey(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
dmlChannelsFunc := getDmlChannelsFunc(ctx, qc)
|
||||
factory := newSimpleMockMsgStreamFactory()
|
||||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
|
||||
defer chMgr.removeAllDMLStream()
|
||||
|
||||
_, err = chMgr.getOrCreateDmlStream(ctx, collectionID)
|
||||
assert.NoError(t, err)
|
||||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil)
|
||||
pchans, err := chMgr.getChannels(collectionID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@ -3838,12 +3807,6 @@ func TestPartitionKey(t *testing.T) {
|
||||
_ = idAllocator.Start()
|
||||
defer idAllocator.Close()
|
||||
|
||||
segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
|
||||
assert.NoError(t, err)
|
||||
segAllocator.Init()
|
||||
_ = segAllocator.Start()
|
||||
defer segAllocator.Close()
|
||||
|
||||
partitionNames, err := getDefaultPartitionsInPartitionKeyMode(ctx, "", collectionName)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, common.DefaultPartitionsWithPartitionKey, int64(len(partitionNames)))
|
||||
@ -3888,13 +3851,11 @@ func TestPartitionKey(t *testing.T) {
|
||||
UpsertCnt: 0,
|
||||
Timestamp: 0,
|
||||
},
|
||||
idAllocator: idAllocator,
|
||||
segIDAssigner: segAllocator,
|
||||
chMgr: chMgr,
|
||||
chTicker: ticker,
|
||||
vChannels: nil,
|
||||
pChannels: nil,
|
||||
schema: nil,
|
||||
idAllocator: idAllocator,
|
||||
chMgr: chMgr,
|
||||
vChannels: nil,
|
||||
pChannels: nil,
|
||||
schema: nil,
|
||||
}
|
||||
|
||||
// don't support specify partition name if use partition key
|
||||
@ -3932,10 +3893,9 @@ func TestPartitionKey(t *testing.T) {
|
||||
IdField: nil,
|
||||
},
|
||||
},
|
||||
idAllocator: idAllocator,
|
||||
segIDAssigner: segAllocator,
|
||||
chMgr: chMgr,
|
||||
chTicker: ticker,
|
||||
idAllocator: idAllocator,
|
||||
chMgr: chMgr,
|
||||
chTicker: ticker,
|
||||
}
|
||||
|
||||
// don't support specify partition name if use partition key
|
||||
@ -4060,12 +4020,8 @@ func TestDefaultPartition(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
dmlChannelsFunc := getDmlChannelsFunc(ctx, qc)
|
||||
factory := newSimpleMockMsgStreamFactory()
|
||||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
|
||||
defer chMgr.removeAllDMLStream()
|
||||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil)
|
||||
|
||||
_, err = chMgr.getOrCreateDmlStream(ctx, collectionID)
|
||||
assert.NoError(t, err)
|
||||
pchans, err := chMgr.getChannels(collectionID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@ -4081,12 +4037,6 @@ func TestDefaultPartition(t *testing.T) {
|
||||
_ = idAllocator.Start()
|
||||
defer idAllocator.Close()
|
||||
|
||||
segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
|
||||
assert.NoError(t, err)
|
||||
segAllocator.Init()
|
||||
_ = segAllocator.Start()
|
||||
defer segAllocator.Close()
|
||||
|
||||
nb := 10
|
||||
fieldID := common.StartOfUserFieldID
|
||||
fieldDatas := make([]*schemapb.FieldData, 0)
|
||||
@ -4127,13 +4077,11 @@ func TestDefaultPartition(t *testing.T) {
|
||||
UpsertCnt: 0,
|
||||
Timestamp: 0,
|
||||
},
|
||||
idAllocator: idAllocator,
|
||||
segIDAssigner: segAllocator,
|
||||
chMgr: chMgr,
|
||||
chTicker: ticker,
|
||||
vChannels: nil,
|
||||
pChannels: nil,
|
||||
schema: nil,
|
||||
idAllocator: idAllocator,
|
||||
chMgr: chMgr,
|
||||
vChannels: nil,
|
||||
pChannels: nil,
|
||||
schema: nil,
|
||||
}
|
||||
|
||||
it.insertMsg.PartitionName = ""
|
||||
@ -4167,10 +4115,9 @@ func TestDefaultPartition(t *testing.T) {
|
||||
IdField: nil,
|
||||
},
|
||||
},
|
||||
idAllocator: idAllocator,
|
||||
segIDAssigner: segAllocator,
|
||||
chMgr: chMgr,
|
||||
chTicker: ticker,
|
||||
idAllocator: idAllocator,
|
||||
chMgr: chMgr,
|
||||
chTicker: ticker,
|
||||
}
|
||||
|
||||
ut.req.PartitionName = ""
|
||||
|
||||
@ -17,7 +17,6 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
@ -55,7 +54,6 @@ type upsertTask struct {
|
||||
rowIDs []int64
|
||||
result *milvuspb.MutationResult
|
||||
idAllocator *allocator.IDAllocator
|
||||
segIDAssigner *segIDAssigner
|
||||
collectionID UniqueID
|
||||
chMgr channelsMgr
|
||||
chTicker channelsTimeTicker
|
||||
@ -445,189 +443,6 @@ func (it *upsertTask) PreExecute(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *upsertTask) insertExecute(ctx context.Context, msgPack *msgstream.MsgPack) error {
|
||||
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy insertExecute upsert %d", it.ID()))
|
||||
defer tr.Elapse("insert execute done when insertExecute")
|
||||
|
||||
collectionName := it.upsertMsg.InsertMsg.CollectionName
|
||||
collID, err := globalMetaCache.GetCollectionID(ctx, it.req.GetDbName(), collectionName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
it.upsertMsg.InsertMsg.CollectionID = collID
|
||||
it.upsertMsg.InsertMsg.BeginTimestamp = it.BeginTs()
|
||||
it.upsertMsg.InsertMsg.EndTimestamp = it.EndTs()
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("collectionID", collID))
|
||||
getCacheDur := tr.RecordSpan()
|
||||
|
||||
_, err = it.chMgr.getOrCreateDmlStream(ctx, collID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
getMsgStreamDur := tr.RecordSpan()
|
||||
channelNames, err := it.chMgr.getVChannels(collID)
|
||||
if err != nil {
|
||||
log.Warn("get vChannels failed when insertExecute",
|
||||
zap.Error(err))
|
||||
it.result.Status = merr.Status(err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("send insert request to virtual channels when insertExecute",
|
||||
zap.String("collection", it.req.GetCollectionName()),
|
||||
zap.String("partition", it.req.GetPartitionName()),
|
||||
zap.Int64("collection_id", collID),
|
||||
zap.Strings("virtual_channels", channelNames),
|
||||
zap.Int64("task_id", it.ID()),
|
||||
zap.Duration("get cache duration", getCacheDur),
|
||||
zap.Duration("get msgStream duration", getMsgStreamDur))
|
||||
|
||||
// assign segmentID for insert data and repack data by segmentID
|
||||
var insertMsgPack *msgstream.MsgPack
|
||||
if it.partitionKeys == nil {
|
||||
insertMsgPack, err = repackInsertData(it.TraceCtx(), channelNames, it.upsertMsg.InsertMsg, it.result, it.idAllocator, it.segIDAssigner)
|
||||
} else {
|
||||
insertMsgPack, err = repackInsertDataWithPartitionKey(it.TraceCtx(), channelNames, it.partitionKeys, it.upsertMsg.InsertMsg, it.result, it.idAllocator, it.segIDAssigner)
|
||||
}
|
||||
if err != nil {
|
||||
log.Warn("assign segmentID and repack insert data failed when insertExecute",
|
||||
zap.Error(err))
|
||||
it.result.Status = merr.Status(err)
|
||||
return err
|
||||
}
|
||||
assignSegmentIDDur := tr.RecordSpan()
|
||||
log.Debug("assign segmentID for insert data success when insertExecute",
|
||||
zap.String("collectionName", it.req.CollectionName),
|
||||
zap.Duration("assign segmentID duration", assignSegmentIDDur))
|
||||
msgPack.Msgs = append(msgPack.Msgs, insertMsgPack.Msgs...)
|
||||
|
||||
log.Debug("Proxy Insert Execute done when upsert",
|
||||
zap.String("collectionName", collectionName))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *upsertTask) deleteExecute(ctx context.Context, msgPack *msgstream.MsgPack) (err error) {
|
||||
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy deleteExecute upsert %d", it.ID()))
|
||||
collID := it.upsertMsg.DeleteMsg.CollectionID
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("collectionID", collID))
|
||||
// hash primary keys to channels
|
||||
channelNames, err := it.chMgr.getVChannels(collID)
|
||||
if err != nil {
|
||||
log.Warn("get vChannels failed when deleteExecute", zap.Error(err))
|
||||
it.result.Status = merr.Status(err)
|
||||
return err
|
||||
}
|
||||
it.upsertMsg.DeleteMsg.PrimaryKeys = it.oldIDs
|
||||
it.upsertMsg.DeleteMsg.HashValues = typeutil.HashPK2Channels(it.upsertMsg.DeleteMsg.PrimaryKeys, channelNames)
|
||||
|
||||
// repack delete msg by dmChannel
|
||||
result := make(map[uint32]msgstream.TsMsg)
|
||||
collectionName := it.upsertMsg.DeleteMsg.CollectionName
|
||||
collectionID := it.upsertMsg.DeleteMsg.CollectionID
|
||||
partitionID := it.upsertMsg.DeleteMsg.PartitionID
|
||||
partitionName := it.upsertMsg.DeleteMsg.PartitionName
|
||||
proxyID := it.upsertMsg.DeleteMsg.Base.SourceID
|
||||
for index, key := range it.upsertMsg.DeleteMsg.HashValues {
|
||||
ts := it.upsertMsg.DeleteMsg.Timestamps[index]
|
||||
_, ok := result[key]
|
||||
if !ok {
|
||||
msgid, err := it.idAllocator.AllocOne()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to allocate MsgID for delete of upsert")
|
||||
}
|
||||
sliceRequest := &msgpb.DeleteRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_Delete),
|
||||
commonpbutil.WithTimeStamp(ts),
|
||||
// id of upsertTask were set as ts in scheduler
|
||||
// msgid of delete msg must be set
|
||||
// or it will be seen as duplicated msg in mq
|
||||
commonpbutil.WithMsgID(msgid),
|
||||
commonpbutil.WithSourceID(proxyID),
|
||||
),
|
||||
CollectionID: collectionID,
|
||||
PartitionID: partitionID,
|
||||
CollectionName: collectionName,
|
||||
PartitionName: partitionName,
|
||||
PrimaryKeys: &schemapb.IDs{},
|
||||
}
|
||||
deleteMsg := &msgstream.DeleteMsg{
|
||||
BaseMsg: msgstream.BaseMsg{
|
||||
Ctx: ctx,
|
||||
},
|
||||
DeleteRequest: sliceRequest,
|
||||
}
|
||||
result[key] = deleteMsg
|
||||
}
|
||||
curMsg := result[key].(*msgstream.DeleteMsg)
|
||||
curMsg.HashValues = append(curMsg.HashValues, it.upsertMsg.DeleteMsg.HashValues[index])
|
||||
curMsg.Timestamps = append(curMsg.Timestamps, it.upsertMsg.DeleteMsg.Timestamps[index])
|
||||
typeutil.AppendIDs(curMsg.PrimaryKeys, it.upsertMsg.DeleteMsg.PrimaryKeys, index)
|
||||
curMsg.NumRows++
|
||||
curMsg.ShardName = channelNames[key]
|
||||
}
|
||||
|
||||
// send delete request to log broker
|
||||
deleteMsgPack := &msgstream.MsgPack{
|
||||
BeginTs: it.upsertMsg.DeleteMsg.BeginTs(),
|
||||
EndTs: it.upsertMsg.DeleteMsg.EndTs(),
|
||||
}
|
||||
for _, msg := range result {
|
||||
if msg != nil {
|
||||
deleteMsgPack.Msgs = append(deleteMsgPack.Msgs, msg)
|
||||
}
|
||||
}
|
||||
msgPack.Msgs = append(msgPack.Msgs, deleteMsgPack.Msgs...)
|
||||
|
||||
log.Debug("Proxy Upsert deleteExecute done", zap.Int64("collection_id", collID),
|
||||
zap.Strings("virtual_channels", channelNames), zap.Int64("taskID", it.ID()),
|
||||
zap.Duration("prepare duration", tr.ElapseSpan()))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *upsertTask) Execute(ctx context.Context) (err error) {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-Execute")
|
||||
defer sp.End()
|
||||
log := log.Ctx(ctx).With(zap.String("collectionName", it.req.CollectionName))
|
||||
|
||||
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute upsert %d", it.ID()))
|
||||
stream, err := it.chMgr.getOrCreateDmlStream(ctx, it.collectionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
msgPack := &msgstream.MsgPack{
|
||||
BeginTs: it.BeginTs(),
|
||||
EndTs: it.EndTs(),
|
||||
}
|
||||
err = it.insertExecute(ctx, msgPack)
|
||||
if err != nil {
|
||||
log.Warn("Fail to insertExecute", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
err = it.deleteExecute(ctx, msgPack)
|
||||
if err != nil {
|
||||
log.Warn("Fail to deleteExecute", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
tr.RecordSpan()
|
||||
err = stream.Produce(ctx, msgPack)
|
||||
if err != nil {
|
||||
it.result.Status = merr.Status(err)
|
||||
return err
|
||||
}
|
||||
sendMsgDur := tr.RecordSpan()
|
||||
metrics.ProxySendMutationReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.UpsertLabel).Observe(float64(sendMsgDur.Milliseconds()))
|
||||
totalDur := tr.ElapseSpan()
|
||||
log.Debug("Proxy Upsert Execute done", zap.Int64("taskID", it.ID()),
|
||||
zap.Duration("total duration", totalDur))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (it *upsertTask) PostExecute(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -15,11 +15,7 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
type upsertTaskByStreamingService struct {
|
||||
*upsertTask
|
||||
}
|
||||
|
||||
func (ut *upsertTaskByStreamingService) Execute(ctx context.Context) error {
|
||||
func (ut *upsertTask) Execute(ctx context.Context) error {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-Execute")
|
||||
defer sp.End()
|
||||
log := log.Ctx(ctx).With(zap.String("collectionName", ut.req.CollectionName))
|
||||
@ -46,7 +42,7 @@ func (ut *upsertTaskByStreamingService) Execute(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ut *upsertTaskByStreamingService) packInsertMessage(ctx context.Context) ([]message.MutableMessage, error) {
|
||||
func (ut *upsertTask) packInsertMessage(ctx context.Context) ([]message.MutableMessage, error) {
|
||||
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy insertExecute upsert %d", ut.ID()))
|
||||
defer tr.Elapse("insert execute done when insertExecute")
|
||||
|
||||
@ -93,7 +89,7 @@ func (ut *upsertTaskByStreamingService) packInsertMessage(ctx context.Context) (
|
||||
return msgs, nil
|
||||
}
|
||||
|
||||
func (it *upsertTaskByStreamingService) packDeleteMessage(ctx context.Context) ([]message.MutableMessage, error) {
|
||||
func (it *upsertTask) packDeleteMessage(ctx context.Context) ([]message.MutableMessage, error) {
|
||||
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy deleteExecute upsert %d", it.ID()))
|
||||
collID := it.upsertMsg.DeleteMsg.CollectionID
|
||||
it.upsertMsg.DeleteMsg.PrimaryKeys = it.oldIDs
|
||||
|
||||
@ -2286,184 +2286,6 @@ func checkDynamicFieldData(schema *schemapb.CollectionSchema, insertMsg *msgstre
|
||||
return nil
|
||||
}
|
||||
|
||||
func SendReplicateMessagePack(ctx context.Context, replicateMsgStream msgstream.MsgStream, request interface{ GetBase() *commonpb.MsgBase }) {
|
||||
if replicateMsgStream == nil || request == nil {
|
||||
log.Ctx(ctx).Warn("replicate msg stream or request is nil", zap.Any("request", request))
|
||||
return
|
||||
}
|
||||
msgBase := request.GetBase()
|
||||
ts := msgBase.GetTimestamp()
|
||||
if msgBase.GetReplicateInfo().GetIsReplicate() {
|
||||
ts = msgBase.GetReplicateInfo().GetMsgTimestamp()
|
||||
}
|
||||
getBaseMsg := func(ctx context.Context, ts uint64) msgstream.BaseMsg {
|
||||
return msgstream.BaseMsg{
|
||||
Ctx: ctx,
|
||||
HashValues: []uint32{0},
|
||||
BeginTimestamp: ts,
|
||||
EndTimestamp: ts,
|
||||
}
|
||||
}
|
||||
|
||||
var tsMsg msgstream.TsMsg
|
||||
switch r := request.(type) {
|
||||
case *milvuspb.AlterCollectionRequest:
|
||||
tsMsg = &msgstream.AlterCollectionMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
AlterCollectionRequest: r,
|
||||
}
|
||||
case *milvuspb.AlterCollectionFieldRequest:
|
||||
tsMsg = &msgstream.AlterCollectionFieldMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
AlterCollectionFieldRequest: r,
|
||||
}
|
||||
case *milvuspb.RenameCollectionRequest:
|
||||
tsMsg = &msgstream.RenameCollectionMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
RenameCollectionRequest: r,
|
||||
}
|
||||
case *milvuspb.CreateDatabaseRequest:
|
||||
tsMsg = &msgstream.CreateDatabaseMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
CreateDatabaseRequest: r,
|
||||
}
|
||||
case *milvuspb.DropDatabaseRequest:
|
||||
tsMsg = &msgstream.DropDatabaseMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
DropDatabaseRequest: r,
|
||||
}
|
||||
case *milvuspb.AlterDatabaseRequest:
|
||||
tsMsg = &msgstream.AlterDatabaseMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
AlterDatabaseRequest: r,
|
||||
}
|
||||
case *milvuspb.FlushRequest:
|
||||
tsMsg = &msgstream.FlushMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
FlushRequest: r,
|
||||
}
|
||||
case *milvuspb.LoadCollectionRequest:
|
||||
tsMsg = &msgstream.LoadCollectionMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
LoadCollectionRequest: r,
|
||||
}
|
||||
case *milvuspb.ReleaseCollectionRequest:
|
||||
tsMsg = &msgstream.ReleaseCollectionMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
ReleaseCollectionRequest: r,
|
||||
}
|
||||
case *milvuspb.CreateIndexRequest:
|
||||
tsMsg = &msgstream.CreateIndexMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
CreateIndexRequest: r,
|
||||
}
|
||||
case *milvuspb.DropIndexRequest:
|
||||
tsMsg = &msgstream.DropIndexMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
DropIndexRequest: r,
|
||||
}
|
||||
case *milvuspb.LoadPartitionsRequest:
|
||||
tsMsg = &msgstream.LoadPartitionsMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
LoadPartitionsRequest: r,
|
||||
}
|
||||
case *milvuspb.ReleasePartitionsRequest:
|
||||
tsMsg = &msgstream.ReleasePartitionsMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
ReleasePartitionsRequest: r,
|
||||
}
|
||||
case *milvuspb.AlterIndexRequest:
|
||||
tsMsg = &msgstream.AlterIndexMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
AlterIndexRequest: r,
|
||||
}
|
||||
case *milvuspb.CreateCredentialRequest:
|
||||
tsMsg = &msgstream.CreateUserMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
CreateCredentialRequest: r,
|
||||
}
|
||||
case *milvuspb.UpdateCredentialRequest:
|
||||
tsMsg = &msgstream.UpdateUserMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
UpdateCredentialRequest: r,
|
||||
}
|
||||
case *milvuspb.DeleteCredentialRequest:
|
||||
tsMsg = &msgstream.DeleteUserMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
DeleteCredentialRequest: r,
|
||||
}
|
||||
case *milvuspb.CreateRoleRequest:
|
||||
tsMsg = &msgstream.CreateRoleMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
CreateRoleRequest: r,
|
||||
}
|
||||
case *milvuspb.DropRoleRequest:
|
||||
tsMsg = &msgstream.DropRoleMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
DropRoleRequest: r,
|
||||
}
|
||||
case *milvuspb.OperateUserRoleRequest:
|
||||
tsMsg = &msgstream.OperateUserRoleMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
OperateUserRoleRequest: r,
|
||||
}
|
||||
case *milvuspb.OperatePrivilegeRequest:
|
||||
tsMsg = &msgstream.OperatePrivilegeMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
OperatePrivilegeRequest: r,
|
||||
}
|
||||
case *milvuspb.OperatePrivilegeV2Request:
|
||||
tsMsg = &msgstream.OperatePrivilegeV2Msg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
OperatePrivilegeV2Request: r,
|
||||
}
|
||||
case *milvuspb.CreatePrivilegeGroupRequest:
|
||||
tsMsg = &msgstream.CreatePrivilegeGroupMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
CreatePrivilegeGroupRequest: r,
|
||||
}
|
||||
case *milvuspb.DropPrivilegeGroupRequest:
|
||||
tsMsg = &msgstream.DropPrivilegeGroupMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
DropPrivilegeGroupRequest: r,
|
||||
}
|
||||
case *milvuspb.OperatePrivilegeGroupRequest:
|
||||
tsMsg = &msgstream.OperatePrivilegeGroupMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
OperatePrivilegeGroupRequest: r,
|
||||
}
|
||||
case *milvuspb.CreateAliasRequest:
|
||||
tsMsg = &msgstream.CreateAliasMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
CreateAliasRequest: r,
|
||||
}
|
||||
case *milvuspb.DropAliasRequest:
|
||||
tsMsg = &msgstream.DropAliasMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
DropAliasRequest: r,
|
||||
}
|
||||
case *milvuspb.AlterAliasRequest:
|
||||
tsMsg = &msgstream.AlterAliasMsg{
|
||||
BaseMsg: getBaseMsg(ctx, ts),
|
||||
AlterAliasRequest: r,
|
||||
}
|
||||
default:
|
||||
log.Warn("unknown request", zap.Any("request", request))
|
||||
return
|
||||
}
|
||||
msgPack := &msgstream.MsgPack{
|
||||
BeginTs: ts,
|
||||
EndTs: ts,
|
||||
Msgs: []msgstream.TsMsg{tsMsg},
|
||||
}
|
||||
msgErr := replicateMsgStream.Produce(ctx, msgPack)
|
||||
// ignore the error if the msg stream failed to produce the msg,
|
||||
// because it can be manually fixed in this error
|
||||
if msgErr != nil {
|
||||
log.Warn("send replicate msg failed", zap.Any("pack", msgPack), zap.Error(msgErr))
|
||||
}
|
||||
}
|
||||
|
||||
func GetCachedCollectionSchema(ctx context.Context, dbName string, colName string) (*schemaInfo, error) {
|
||||
if globalMetaCache != nil {
|
||||
return globalMetaCache.GetCollectionSchema(ctx, dbName, colName)
|
||||
|
||||
@ -2302,55 +2302,6 @@ func Test_validateMaxCapacityPerRow(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestSendReplicateMessagePack(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mockStream := msgstream.NewMockMsgStream(t)
|
||||
|
||||
t.Run("empty case", func(t *testing.T) {
|
||||
SendReplicateMessagePack(ctx, nil, nil)
|
||||
})
|
||||
|
||||
t.Run("produce fail", func(t *testing.T) {
|
||||
mockStream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("produce error")).Once()
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateDatabaseRequest{
|
||||
Base: &commonpb.MsgBase{ReplicateInfo: &commonpb.ReplicateInfo{
|
||||
IsReplicate: true,
|
||||
MsgTimestamp: 100,
|
||||
}},
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("unknown request", func(t *testing.T) {
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.ListDatabasesRequest{})
|
||||
})
|
||||
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
mockStream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.AlterCollectionRequest{})
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.AlterCollectionFieldRequest{})
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.RenameCollectionRequest{})
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateDatabaseRequest{})
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.DropDatabaseRequest{})
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.FlushRequest{})
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.LoadCollectionRequest{})
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.ReleaseCollectionRequest{})
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateIndexRequest{})
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.DropIndexRequest{})
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.LoadPartitionsRequest{})
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.ReleasePartitionsRequest{})
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateCredentialRequest{})
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.DeleteCredentialRequest{})
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateRoleRequest{})
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.DropRoleRequest{})
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.OperateUserRoleRequest{})
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.OperatePrivilegeRequest{})
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateAliasRequest{})
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.DropAliasRequest{})
|
||||
SendReplicateMessagePack(ctx, mockStream, &milvuspb.AlterAliasRequest{})
|
||||
})
|
||||
}
|
||||
|
||||
func TestAppendUserInfoForRPC(t *testing.T) {
|
||||
ctx := GetContext(context.Background(), "root:123456")
|
||||
ctx = AppendUserInfoForRPC(ctx)
|
||||
|
||||
@ -48,12 +48,10 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/util/function"
|
||||
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||
"github.com/milvus-io/milvus/internal/util/searchutil/optimizers"
|
||||
"github.com/milvus-io/milvus/internal/util/streamingutil"
|
||||
"github.com/milvus-io/milvus/internal/util/streamrpc"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
|
||||
@ -136,8 +134,6 @@ type shardDelegator struct {
|
||||
// stream delete buffer
|
||||
deleteMut sync.RWMutex
|
||||
deleteBuffer deletebuffer.DeleteBuffer[*deletebuffer.Item]
|
||||
// dispatcherClient msgdispatcher.Client
|
||||
factory msgstream.Factory
|
||||
|
||||
sf conc.Singleflight[struct{}]
|
||||
loader segments.Loader
|
||||
@ -931,7 +927,7 @@ func (sd *shardDelegator) speedupGuranteeTS(
|
||||
// when 1. streaming service is disable,
|
||||
// 2. consistency level is not strong,
|
||||
// 3. cannot speed iterator, because current client of milvus doesn't support shard level mvcc.
|
||||
if !streamingutil.IsStreamingServiceEnabled() || isIterator || cl != commonpb.ConsistencyLevel_Strong || mvccTS != 0 {
|
||||
if isIterator || cl != commonpb.ConsistencyLevel_Strong || mvccTS != 0 {
|
||||
return guaranteeTS
|
||||
}
|
||||
// use the mvcc timestamp of the wal as the guarantee timestamp to make fast strong consistency search.
|
||||
@ -1145,8 +1141,7 @@ func (sd *shardDelegator) loadPartitionStats(ctx context.Context, partStatsVersi
|
||||
|
||||
// NewShardDelegator creates a new ShardDelegator instance with all fields initialized.
|
||||
func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID UniqueID, channel string, version int64,
|
||||
workerManager cluster.Manager, manager *segments.Manager, loader segments.Loader,
|
||||
factory msgstream.Factory, startTs uint64, queryHook optimizers.QueryHook, chunkManager storage.ChunkManager,
|
||||
workerManager cluster.Manager, manager *segments.Manager, loader segments.Loader, startTs uint64, queryHook optimizers.QueryHook, chunkManager storage.ChunkManager,
|
||||
queryView *channelQueryView,
|
||||
) (ShardDelegator, error) {
|
||||
log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID),
|
||||
@ -1184,7 +1179,6 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
|
||||
pkOracle: pkoracle.NewPkOracle(),
|
||||
latestTsafe: atomic.NewUint64(startTs),
|
||||
loader: loader,
|
||||
factory: factory,
|
||||
queryHook: queryHook,
|
||||
chunkManager: chunkManager,
|
||||
partitionStats: make(map[UniqueID]*storage.PartitionStatsSnapshot),
|
||||
|
||||
@ -19,7 +19,6 @@ package delegator
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
@ -40,8 +39,6 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
||||
mqcommon "github.com/milvus-io/milvus/pkg/v2/mq/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
|
||||
@ -771,120 +768,6 @@ func (sd *shardDelegator) loadStreamDelete(ctx context.Context,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sd *shardDelegator) createStreamFromMsgStream(ctx context.Context, position *msgpb.MsgPosition) (ch <-chan *msgstream.MsgPack, closer func(), err error) {
|
||||
stream, err := sd.factory.NewTtMsgStream(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer stream.Close()
|
||||
vchannelName := position.ChannelName
|
||||
pChannelName := funcutil.ToPhysicalChannel(vchannelName)
|
||||
position.ChannelName = pChannelName
|
||||
|
||||
ts, _ := tsoutil.ParseTS(position.Timestamp)
|
||||
|
||||
// Random the subname in case we trying to load same delta at the same time
|
||||
subName := fmt.Sprintf("querynode-delta-loader-%d-%d-%d", paramtable.GetNodeID(), sd.collectionID, rand.Int())
|
||||
log.Info("from dml check point load delete", zap.Any("position", position), zap.String("vChannel", vchannelName), zap.String("subName", subName), zap.Time("positionTs", ts))
|
||||
err = stream.AsConsumer(context.TODO(), []string{pChannelName}, subName, mqcommon.SubscriptionPositionUnknown)
|
||||
if err != nil {
|
||||
return nil, stream.Close, err
|
||||
}
|
||||
|
||||
err = stream.Seek(context.TODO(), []*msgpb.MsgPosition{position}, false)
|
||||
if err != nil {
|
||||
return nil, stream.Close, err
|
||||
}
|
||||
|
||||
dispatcher := msgstream.NewSimpleMsgDispatcher(stream, func(pm msgstream.ConsumeMsg) bool {
|
||||
if pm.GetType() != commonpb.MsgType_Delete || pm.GetVChannel() != vchannelName {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return dispatcher.Chan(), dispatcher.Close, nil
|
||||
}
|
||||
|
||||
// Only used in test.
|
||||
func (sd *shardDelegator) readDeleteFromMsgstream(ctx context.Context, position *msgpb.MsgPosition, safeTs uint64, candidate *pkoracle.BloomFilterSet) (*storage.DeleteData, error) {
|
||||
log := sd.getLogger(ctx).With(
|
||||
zap.String("channel", position.ChannelName),
|
||||
zap.Int64("segmentID", candidate.ID()),
|
||||
)
|
||||
pChannelName := funcutil.ToPhysicalChannel(position.ChannelName)
|
||||
|
||||
var ch <-chan *msgstream.MsgPack
|
||||
var closer func()
|
||||
var err error
|
||||
ch, closer, err = sd.createStreamFromMsgStream(ctx, position)
|
||||
if closer != nil {
|
||||
defer closer()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
result := &storage.DeleteData{}
|
||||
hasMore := true
|
||||
for hasMore {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Debug("read delta msg from seek position done", zap.Error(ctx.Err()))
|
||||
return nil, ctx.Err()
|
||||
case msgPack, ok := <-ch:
|
||||
if !ok {
|
||||
err = fmt.Errorf("stream channel closed, pChannelName=%v, msgID=%v", pChannelName, position.GetMsgID())
|
||||
log.Warn("fail to read delta msg",
|
||||
zap.String("pChannelName", pChannelName),
|
||||
zap.Binary("msgID", position.GetMsgID()),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if msgPack == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, tsMsg := range msgPack.Msgs {
|
||||
if tsMsg.Type() == commonpb.MsgType_Delete {
|
||||
dmsg := tsMsg.(*msgstream.DeleteMsg)
|
||||
if dmsg.CollectionID != sd.collectionID || (dmsg.GetPartitionID() != common.AllPartitionsID && dmsg.GetPartitionID() != candidate.Partition()) {
|
||||
continue
|
||||
}
|
||||
|
||||
pks := storage.ParseIDs2PrimaryKeys(dmsg.GetPrimaryKeys())
|
||||
batchSize := paramtable.Get().CommonCfg.BloomFilterApplyBatchSize.GetAsInt()
|
||||
for idx := 0; idx < len(pks); idx += batchSize {
|
||||
endIdx := idx + batchSize
|
||||
if endIdx > len(pks) {
|
||||
endIdx = len(pks)
|
||||
}
|
||||
|
||||
lc := storage.NewBatchLocationsCache(pks[idx:endIdx])
|
||||
hits := candidate.BatchPkExist(lc)
|
||||
for i, hit := range hits {
|
||||
if hit {
|
||||
result.Pks = append(result.Pks, pks[idx+i])
|
||||
result.Tss = append(result.Tss, dmsg.Timestamps[idx+i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reach safe ts
|
||||
if safeTs <= msgPack.EndPositions[0].GetTimestamp() {
|
||||
hasMore = false
|
||||
}
|
||||
}
|
||||
}
|
||||
log.Info("successfully read delete from stream ", zap.Duration("time spent", time.Since(start)))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ReleaseSegments releases segments local or remotely depending on the target node.
|
||||
func (sd *shardDelegator) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest, force bool) error {
|
||||
log := sd.getLogger(ctx)
|
||||
|
||||
@ -192,11 +192,7 @@ func (s *DelegatorDataSuite) genCollectionWithFunction() {
|
||||
}},
|
||||
}, nil, &querypb.LoadMetaInfo{SchemaVersion: tsoutil.ComposeTSByTime(time.Now(), 0)})
|
||||
|
||||
delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, &msgstream.MockMqFactory{
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
s.NoError(err)
|
||||
s.delegator = delegator.(*shardDelegator)
|
||||
}
|
||||
@ -214,11 +210,7 @@ func (s *DelegatorDataSuite) SetupTest() {
|
||||
s.rootPath = s.Suite.T().Name()
|
||||
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath)
|
||||
s.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(context.Background())
|
||||
delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, &msgstream.MockMqFactory{
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
s.Require().NoError(err)
|
||||
sd, ok := delegator.(*shardDelegator)
|
||||
s.Require().True(ok)
|
||||
@ -806,11 +798,7 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
|
||||
s.workerManager,
|
||||
s.manager,
|
||||
s.loader,
|
||||
&msgstream.MockMqFactory{
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000, nil, nil, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
10000, nil, nil, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
s.NoError(err)
|
||||
|
||||
growing0 := segments.NewMockSegment(s.T())
|
||||
@ -1524,39 +1512,6 @@ func (s *DelegatorDataSuite) TestLevel0Deletions() {
|
||||
s.Empty(pks)
|
||||
}
|
||||
|
||||
func (s *DelegatorDataSuite) TestReadDeleteFromMsgstream() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.mq.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
s.mq.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
s.mq.EXPECT().Close()
|
||||
ch := make(chan *msgstream.ConsumeMsgPack, 10)
|
||||
s.mq.EXPECT().Chan().Return(ch)
|
||||
|
||||
oracle := pkoracle.NewBloomFilterSet(1, 1, commonpb.SegmentState_Sealed)
|
||||
oracle.UpdateBloomFilter([]storage.PrimaryKey{storage.NewInt64PrimaryKey(1), storage.NewInt64PrimaryKey(2)})
|
||||
|
||||
baseMsg := &commonpb.MsgBase{MsgType: commonpb.MsgType_Delete}
|
||||
|
||||
datas := []*msgstream.MsgPack{
|
||||
{EndTs: 10, EndPositions: []*msgpb.MsgPosition{{Timestamp: 10}}, Msgs: []msgstream.TsMsg{
|
||||
&msgstream.DeleteMsg{DeleteRequest: &msgpb.DeleteRequest{Base: baseMsg, CollectionID: s.collectionID, PartitionID: 1, PrimaryKeys: storage.ParseInt64s2IDs(1), Timestamps: []uint64{1}}},
|
||||
&msgstream.DeleteMsg{DeleteRequest: &msgpb.DeleteRequest{Base: baseMsg, CollectionID: s.collectionID, PartitionID: -1, PrimaryKeys: storage.ParseInt64s2IDs(2), Timestamps: []uint64{5}}},
|
||||
// invalid msg because partition wrong
|
||||
&msgstream.DeleteMsg{DeleteRequest: &msgpb.DeleteRequest{Base: baseMsg, CollectionID: s.collectionID, PartitionID: 2, PrimaryKeys: storage.ParseInt64s2IDs(1), Timestamps: []uint64{10}}},
|
||||
}},
|
||||
}
|
||||
|
||||
for _, data := range datas {
|
||||
ch <- msgstream.BuildConsumeMsgPack(data)
|
||||
}
|
||||
|
||||
result, err := s.delegator.readDeleteFromMsgstream(ctx, &msgpb.MsgPosition{Timestamp: 0}, 10, oracle)
|
||||
s.NoError(err)
|
||||
s.Equal(2, len(result.Pks))
|
||||
}
|
||||
|
||||
func (s *DelegatorDataSuite) TestDelegatorData_ExcludeSegments() {
|
||||
s.delegator.AddExcludedSegments(map[int64]uint64{
|
||||
1: 3,
|
||||
|
||||
@ -19,6 +19,7 @@ package delegator
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -33,6 +34,7 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/distributed/streaming"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/cluster"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
@ -51,6 +53,12 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
streaming.SetupNoopWALForTest()
|
||||
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
type DelegatorSuite struct {
|
||||
suite.Suite
|
||||
|
||||
@ -164,11 +172,7 @@ func (s *DelegatorSuite) SetupTest() {
|
||||
|
||||
var err error
|
||||
// s.delegator, err = NewShardDelegator(s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.tsafeManager, s.loader)
|
||||
s.delegator, err = NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, &msgstream.MockMqFactory{
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
s.delegator, err = NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
@ -203,11 +207,7 @@ func (s *DelegatorSuite) TestCreateDelegatorWithFunction() {
|
||||
}},
|
||||
}, nil, &querypb.LoadMetaInfo{SchemaVersion: tsoutil.ComposeTSByTime(time.Now(), 0)})
|
||||
|
||||
_, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, manager, s.loader, &msgstream.MockMqFactory{
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
_, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
@ -246,11 +246,7 @@ func (s *DelegatorSuite) TestCreateDelegatorWithFunction() {
|
||||
}},
|
||||
}, nil, &querypb.LoadMetaInfo{SchemaVersion: tsoutil.ComposeTSByTime(time.Now(), 0)})
|
||||
|
||||
_, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, manager, s.loader, &msgstream.MockMqFactory{
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
_, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
s.NoError(err)
|
||||
})
|
||||
}
|
||||
@ -1410,11 +1406,7 @@ func (s *DelegatorSuite) TestUpdateSchema() {
|
||||
func (s *DelegatorSuite) ResetDelegator() {
|
||||
var err error
|
||||
s.delegator.Close()
|
||||
s.delegator, err = NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, &msgstream.MockMqFactory{
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
s.delegator, err = NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
|
||||
@ -151,11 +151,7 @@ func (s *StreamingForwardSuite) SetupTest() {
|
||||
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath)
|
||||
s.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(context.Background())
|
||||
|
||||
delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, &msgstream.MockMqFactory{
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
s.Require().NoError(err)
|
||||
|
||||
sd, ok := delegator.(*shardDelegator)
|
||||
@ -394,11 +390,7 @@ func (s *GrowingMergeL0Suite) SetupTest() {
|
||||
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath)
|
||||
s.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(context.Background())
|
||||
|
||||
delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, &msgstream.MockMqFactory{
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
s.Require().NoError(err)
|
||||
|
||||
sd, ok := delegator.(*shardDelegator)
|
||||
|
||||
@ -65,7 +65,6 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/util/searchutil/scheduler"
|
||||
"github.com/milvus-io/milvus/internal/util/segcore"
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
"github.com/milvus-io/milvus/internal/util/streamingutil"
|
||||
"github.com/milvus-io/milvus/internal/util/streamingutil/util"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/config"
|
||||
@ -528,11 +527,7 @@ func (node *QueryNode) Init() error {
|
||||
node.manager = segments.NewManager()
|
||||
node.loader = segments.NewLoader(node.ctx, node.manager, node.chunkManager)
|
||||
node.manager.SetLoader(node.loader)
|
||||
if streamingutil.IsStreamingServiceEnabled() {
|
||||
node.dispClient = msgdispatcher.NewClientWithIncludeSkipWhenSplit(streaming.NewDelegatorMsgstreamFactory(), typeutil.QueryNodeRole, node.GetNodeID())
|
||||
} else {
|
||||
node.dispClient = msgdispatcher.NewClient(node.factory, typeutil.QueryNodeRole, node.GetNodeID())
|
||||
}
|
||||
node.dispClient = msgdispatcher.NewClientWithIncludeSkipWhenSplit(streaming.NewDelegatorMsgstreamFactory(), typeutil.QueryNodeRole, node.GetNodeID())
|
||||
// init pipeline manager
|
||||
node.pipelineManager = pipeline.NewManager(node.manager, node.dispClient, node.delegators)
|
||||
|
||||
|
||||
@ -267,7 +267,6 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
|
||||
node.clusterManager,
|
||||
node.manager,
|
||||
node.loader,
|
||||
node.factory,
|
||||
channel.GetSeekPosition().GetTimestamp(),
|
||||
node.queryHook,
|
||||
node.chunkManager,
|
||||
|
||||
@ -49,12 +49,12 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/util/streamrpc"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/conc"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/etcd"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
||||
@ -68,7 +68,6 @@ import (
|
||||
type ServiceSuite struct {
|
||||
suite.Suite
|
||||
// Data
|
||||
msgChan chan *msgstream.ConsumeMsgPack
|
||||
collectionID int64
|
||||
collectionName string
|
||||
schema *schemapb.CollectionSchema
|
||||
@ -92,8 +91,7 @@ type ServiceSuite struct {
|
||||
chunkManagerFactory *storage.ChunkManagerFactory
|
||||
|
||||
// Mock
|
||||
factory *dependency.MockFactory
|
||||
msgStream *msgstream.MockMsgStream
|
||||
factory *dependency.MockFactory
|
||||
}
|
||||
|
||||
func (suite *ServiceSuite) SetupSuite() {
|
||||
@ -129,7 +127,6 @@ func (suite *ServiceSuite) SetupTest() {
|
||||
ctx := context.Background()
|
||||
// init mock
|
||||
suite.factory = dependency.NewMockFactory(suite.T())
|
||||
suite.msgStream = msgstream.NewMockMsgStream(suite.T())
|
||||
// TODO:: cpp chunk manager not support local chunk manager
|
||||
paramtable.Get().Save(paramtable.Get().LocalStorageCfg.Path.Key, suite.T().TempDir())
|
||||
// suite.chunkManagerFactory = storage.NewChunkManagerFactory("local", storage.RootPath("/tmp/milvus-test"))
|
||||
@ -315,13 +312,6 @@ func (suite *ServiceSuite) TestWatchDmChannelsInt64() {
|
||||
IndexInfoList: mock_segcore.GenTestIndexInfoList(suite.collectionID, schema),
|
||||
}
|
||||
|
||||
// mocks
|
||||
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil).Maybe()
|
||||
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil).Maybe()
|
||||
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()
|
||||
suite.msgStream.EXPECT().Chan().Return(suite.msgChan).Maybe()
|
||||
suite.msgStream.EXPECT().Close().Maybe()
|
||||
|
||||
// watchDmChannels
|
||||
status, err := suite.node.WatchDmChannels(ctx, req)
|
||||
suite.NoError(err)
|
||||
@ -367,13 +357,6 @@ func (suite *ServiceSuite) TestWatchDmChannelsVarchar() {
|
||||
IndexInfoList: mock_segcore.GenTestIndexInfoList(suite.collectionID, schema),
|
||||
}
|
||||
|
||||
// mocks
|
||||
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil).Maybe()
|
||||
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil).Maybe()
|
||||
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()
|
||||
suite.msgStream.EXPECT().Chan().Return(suite.msgChan).Maybe()
|
||||
suite.msgStream.EXPECT().Close().Maybe()
|
||||
|
||||
// watchDmChannels
|
||||
status, err := suite.node.WatchDmChannels(ctx, req)
|
||||
suite.NoError(err)
|
||||
@ -2437,6 +2420,13 @@ func TestQueryNodeService(t *testing.T) {
|
||||
local.EXPECT().GetLatestMVCCTimestampIfLocal(mock.Anything, mock.Anything).Return(0, nil).Maybe()
|
||||
local.EXPECT().GetMetricsIfLocal(mock.Anything).Return(&types.StreamingNodeMetrics{}, nil).Maybe()
|
||||
wal.EXPECT().Local().Return(local).Maybe()
|
||||
wal.EXPECT().WALName().Return(rmq.WALName).Maybe()
|
||||
scanner := mock_streaming.NewMockScanner(t)
|
||||
scanner.EXPECT().Done().Return(make(chan struct{})).Maybe()
|
||||
scanner.EXPECT().Error().Return(nil).Maybe()
|
||||
scanner.EXPECT().Close().Return().Maybe()
|
||||
wal.EXPECT().Read(mock.Anything, mock.Anything).Return(scanner).Maybe()
|
||||
|
||||
streaming.SetWALForTest(wal)
|
||||
defer streaming.RecoverWALForTest()
|
||||
|
||||
|
||||
@ -1,98 +0,0 @@
|
||||
package rootcoord
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/streamingutil/util"
|
||||
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
|
||||
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message/adaptor"
|
||||
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
|
||||
)
|
||||
|
||||
var _ task = (*broadcastTask)(nil)
|
||||
|
||||
// BroadcastTask is used to implement the broadcast operation based on the msgstream
|
||||
// by using the streaming service interface.
|
||||
// msgstream will be deprecated since 2.6.0 with streaming service, so those code will be removed in the future version.
|
||||
type broadcastTask struct {
|
||||
baseTask
|
||||
msgs []message.MutableMessage // The message wait for broadcast
|
||||
walName string
|
||||
resultFuture *syncutil.Future[types.AppendResponses]
|
||||
}
|
||||
|
||||
func (b *broadcastTask) Execute(ctx context.Context) error {
|
||||
result := types.NewAppendResponseN(len(b.msgs))
|
||||
defer func() {
|
||||
b.resultFuture.Set(result)
|
||||
}()
|
||||
|
||||
for idx, msg := range b.msgs {
|
||||
tsMsg, err := adaptor.NewMsgPackFromMutableMessageV1(msg)
|
||||
tsMsg.SetTs(b.ts) // overwrite the ts.
|
||||
if err != nil {
|
||||
result.FillResponseAtIdx(types.AppendResponse{Error: err}, idx)
|
||||
return err
|
||||
}
|
||||
pchannel := funcutil.ToPhysicalChannel(msg.VChannel())
|
||||
msgID, err := b.core.chanTimeTick.broadcastMarkDmlChannels([]string{pchannel}, &msgstream.MsgPack{
|
||||
BeginTs: b.ts,
|
||||
EndTs: b.ts,
|
||||
Msgs: []msgstream.TsMsg{tsMsg},
|
||||
})
|
||||
if err != nil {
|
||||
result.FillResponseAtIdx(types.AppendResponse{Error: err}, idx)
|
||||
continue
|
||||
}
|
||||
result.FillResponseAtIdx(types.AppendResponse{
|
||||
AppendResult: &types.AppendResult{
|
||||
MessageID: adaptor.MustGetMessageIDFromMQWrapperIDBytes(b.walName, msgID[pchannel]),
|
||||
TimeTick: b.ts,
|
||||
},
|
||||
}, idx)
|
||||
}
|
||||
return result.UnwrapFirstError()
|
||||
}
|
||||
|
||||
func newMsgStreamAppendOperator(c *Core) *msgstreamAppendOperator {
|
||||
return &msgstreamAppendOperator{
|
||||
core: c,
|
||||
walName: util.MustSelectWALName(),
|
||||
}
|
||||
}
|
||||
|
||||
// msgstreamAppendOperator the code of streamingcoord to make broadcast available on the legacy msgstream.
|
||||
// Because msgstream is bound to the rootcoord task, so we transfer each broadcast operation into a ddl task.
|
||||
// to make sure the timetick rule.
|
||||
// The Msgstream will be deprecated since 2.6.0, so we make a single module to hold it.
|
||||
type msgstreamAppendOperator struct {
|
||||
core *Core
|
||||
walName string
|
||||
}
|
||||
|
||||
// AppendMessages implements the AppendOperator interface for broadcaster service at streaming service.
|
||||
func (m *msgstreamAppendOperator) AppendMessages(ctx context.Context, msgs ...message.MutableMessage) types.AppendResponses {
|
||||
t := &broadcastTask{
|
||||
baseTask: newBaseTask(ctx, m.core),
|
||||
msgs: msgs,
|
||||
walName: m.walName,
|
||||
resultFuture: syncutil.NewFuture[types.AppendResponses](),
|
||||
}
|
||||
|
||||
if err := m.core.scheduler.AddTask(t); err != nil {
|
||||
resp := types.NewAppendResponseN(len(msgs))
|
||||
resp.FillAllError(err)
|
||||
return resp
|
||||
}
|
||||
|
||||
result, err := t.resultFuture.GetWithContext(ctx)
|
||||
if err != nil {
|
||||
resp := types.NewAppendResponseN(len(msgs))
|
||||
resp.FillAllError(err)
|
||||
return resp
|
||||
}
|
||||
return result
|
||||
}
|
||||
@ -44,7 +44,6 @@ import (
|
||||
kvmetastore "github.com/milvus-io/milvus/internal/metastore/kv/rootcoord"
|
||||
"github.com/milvus-io/milvus/internal/metastore/model"
|
||||
streamingcoord "github.com/milvus-io/milvus/internal/streamingcoord/server"
|
||||
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
|
||||
tso2 "github.com/milvus-io/milvus/internal/tso"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
@ -694,11 +693,6 @@ func (c *Core) startInternal() error {
|
||||
c.UpdateStateCode(commonpb.StateCode_Healthy)
|
||||
sessionutil.SaveServerInfo(typeutil.MixCoordRole, c.session.GetServerID())
|
||||
log.Info("rootcoord startup successfully")
|
||||
|
||||
// regster the core as a appendoperator for broadcast service.
|
||||
// TODO: should be removed at 2.6.0.
|
||||
// Add the wal accesser to the broadcaster registry for making broadcast operation.
|
||||
registry.Register(registry.AppendOperatorTypeMsgstream, newMsgStreamAppendOperator(c))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@ -12,7 +12,6 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/streamingcoord/client/assignment"
|
||||
"github.com/milvus-io/milvus/internal/streamingcoord/client/broadcast"
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
"github.com/milvus-io/milvus/internal/util/streamingutil"
|
||||
"github.com/milvus-io/milvus/internal/util/streamingutil/service/balancer/picker"
|
||||
streamingserviceinterceptor "github.com/milvus-io/milvus/internal/util/streamingutil/service/interceptor"
|
||||
"github.com/milvus-io/milvus/internal/util/streamingutil/service/lazygrpc"
|
||||
@ -76,11 +75,8 @@ func NewClient(etcdCli *clientv3.Client) Client {
|
||||
dialOptions...,
|
||||
)
|
||||
})
|
||||
var assignmentServiceImpl *assignment.AssignmentServiceImpl
|
||||
if streamingutil.IsStreamingServiceEnabled() {
|
||||
assignmentService := lazygrpc.WithServiceCreator(conn, streamingpb.NewStreamingCoordAssignmentServiceClient)
|
||||
assignmentServiceImpl = assignment.NewAssignmentService(assignmentService)
|
||||
}
|
||||
assignmentService := lazygrpc.WithServiceCreator(conn, streamingpb.NewStreamingCoordAssignmentServiceClient)
|
||||
assignmentServiceImpl := assignment.NewAssignmentService(assignmentService)
|
||||
broadcastService := lazygrpc.WithServiceCreator(conn, streamingpb.NewStreamingCoordBroadcastServiceClient)
|
||||
return &clientImpl{
|
||||
conn: conn,
|
||||
|
||||
@ -1,14 +0,0 @@
|
||||
package broadcaster
|
||||
|
||||
import (
|
||||
"github.com/milvus-io/milvus/internal/distributed/streaming"
|
||||
"github.com/milvus-io/milvus/internal/util/streamingutil"
|
||||
)
|
||||
|
||||
// NewAppendOperator creates an append operator to handle the incoming messages for broadcaster.
|
||||
func NewAppendOperator() AppendOperator {
|
||||
if streamingutil.IsStreamingServiceEnabled() {
|
||||
return streaming.WAL()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -3,7 +3,6 @@ package broadcaster
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
|
||||
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
|
||||
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
|
||||
)
|
||||
@ -19,4 +18,9 @@ type Broadcaster interface {
|
||||
Close()
|
||||
}
|
||||
|
||||
type AppendOperator = registry.AppendOperator
|
||||
// AppendOperator is used to append messages, there's only two implement of this interface:
|
||||
// 1. streaming.WAL()
|
||||
// 2. old msgstream interface [deprecated]
|
||||
type AppendOperator interface {
|
||||
AppendMessages(ctx context.Context, msgs ...message.MutableMessage) types.AppendResponses
|
||||
}
|
||||
|
||||
@ -22,7 +22,6 @@ import (
|
||||
|
||||
func RecoverBroadcaster(
|
||||
ctx context.Context,
|
||||
appendOperator *syncutil.Future[AppendOperator],
|
||||
) (Broadcaster, error) {
|
||||
tasks, err := resource.Resource().StreamingCatalog().ListBroadcastTask(ctx)
|
||||
if err != nil {
|
||||
@ -38,7 +37,6 @@ func RecoverBroadcaster(
|
||||
backoffChan: make(chan *pendingBroadcastTask),
|
||||
pendingChan: make(chan *pendingBroadcastTask),
|
||||
workerChan: make(chan *pendingBroadcastTask),
|
||||
appendOperator: appendOperator,
|
||||
}
|
||||
go b.execute()
|
||||
return b, nil
|
||||
@ -54,7 +52,6 @@ type broadcasterImpl struct {
|
||||
pendingChan chan *pendingBroadcastTask
|
||||
backoffChan chan *pendingBroadcastTask
|
||||
workerChan chan *pendingBroadcastTask
|
||||
appendOperator *syncutil.Future[AppendOperator] // TODO: we can remove those lazy future in 2.6.0, by remove the msgstream broadcaster.
|
||||
}
|
||||
|
||||
// Broadcast broadcasts the message to all channels.
|
||||
@ -140,14 +137,6 @@ func (b *broadcasterImpl) execute() {
|
||||
b.Logger().Info("broadcaster execute exit")
|
||||
}()
|
||||
|
||||
// Wait for appendOperator ready
|
||||
appendOperator, err := b.appendOperator.GetWithContext(b.backgroundTaskNotifier.Context())
|
||||
if err != nil {
|
||||
b.Logger().Info("broadcaster is closed before appendOperator ready")
|
||||
return
|
||||
}
|
||||
b.Logger().Info("broadcaster appendOperator ready, begin to start workers and dispatch")
|
||||
|
||||
// Start n workers to handle the broadcast task.
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < workers; i++ {
|
||||
@ -156,7 +145,7 @@ func (b *broadcasterImpl) execute() {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
b.worker(i, appendOperator)
|
||||
b.worker(i)
|
||||
}()
|
||||
}
|
||||
defer wg.Wait()
|
||||
@ -205,7 +194,7 @@ func (b *broadcasterImpl) dispatch() {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *broadcasterImpl) worker(no int, appendOperator AppendOperator) {
|
||||
func (b *broadcasterImpl) worker(no int) {
|
||||
logger := b.Logger().With(zap.Int("workerNo", no))
|
||||
defer func() {
|
||||
logger.Info("broadcaster worker exit")
|
||||
@ -216,7 +205,7 @@ func (b *broadcasterImpl) worker(no int, appendOperator AppendOperator) {
|
||||
case <-b.backgroundTaskNotifier.Context().Done():
|
||||
return
|
||||
case task := <-b.workerChan:
|
||||
if err := task.Execute(b.backgroundTaskNotifier.Context(), appendOperator); err != nil {
|
||||
if err := task.Execute(b.backgroundTaskNotifier.Context()); err != nil {
|
||||
// If the task is not done, repush it into pendings and retry infinitely.
|
||||
select {
|
||||
case <-b.backgroundTaskNotifier.Context().Done():
|
||||
|
||||
@ -12,8 +12,9 @@ import (
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus/internal/distributed/streaming"
|
||||
"github.com/milvus-io/milvus/internal/mocks/distributed/mock_streaming"
|
||||
"github.com/milvus-io/milvus/internal/mocks/mock_metastore"
|
||||
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_broadcaster"
|
||||
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
|
||||
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
|
||||
internaltypes "github.com/milvus-io/milvus/internal/types"
|
||||
@ -76,8 +77,8 @@ func TestBroadcaster(t *testing.T) {
|
||||
resource.InitForTest(resource.OptStreamingCatalog(meta), resource.OptMixCoordClient(f))
|
||||
|
||||
fbc := syncutil.NewFuture[Broadcaster]()
|
||||
operator, appended := createOpeartor(t, fbc)
|
||||
bc, err := RecoverBroadcaster(context.Background(), operator)
|
||||
appended := createOpeartor(t, fbc)
|
||||
bc, err := RecoverBroadcaster(context.Background())
|
||||
fbc.Set(bc)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, bc)
|
||||
@ -135,10 +136,10 @@ func ack(broadcaster Broadcaster, broadcastID uint64, vchannel string) {
|
||||
}
|
||||
}
|
||||
|
||||
func createOpeartor(t *testing.T, broadcaster *syncutil.Future[Broadcaster]) (*syncutil.Future[AppendOperator], *atomic.Int64) {
|
||||
func createOpeartor(t *testing.T, broadcaster *syncutil.Future[Broadcaster]) *atomic.Int64 {
|
||||
id := atomic.NewInt64(1)
|
||||
appended := atomic.NewInt64(0)
|
||||
operator := mock_broadcaster.NewMockAppendOperator(t)
|
||||
operator := mock_streaming.NewMockWALAccesser(t)
|
||||
f := func(ctx context.Context, msgs ...message.MutableMessage) types.AppendResponses {
|
||||
resps := types.AppendResponses{
|
||||
Responses: make([]types.AppendResponse, len(msgs)),
|
||||
@ -174,9 +175,8 @@ func createOpeartor(t *testing.T, broadcaster *syncutil.Future[Broadcaster]) (*s
|
||||
operator.EXPECT().AppendMessages(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(f)
|
||||
operator.EXPECT().AppendMessages(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(f)
|
||||
|
||||
fOperator := syncutil.NewFuture[AppendOperator]()
|
||||
fOperator.Set(operator)
|
||||
return fOperator, appended
|
||||
streaming.SetWALForTest(operator)
|
||||
return appended
|
||||
}
|
||||
|
||||
func createNewBroadcastMsg(vchannels []string, rks ...message.ResourceKey) message.BroadcastMutableMessage {
|
||||
|
||||
@ -1,42 +0,0 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/streamingutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
|
||||
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
|
||||
)
|
||||
|
||||
type AppendOperatorType int
|
||||
|
||||
const (
|
||||
AppendOperatorTypeMsgstream AppendOperatorType = iota + 1
|
||||
AppendOperatorTypeStreaming
|
||||
)
|
||||
|
||||
var localRegistry = make(map[AppendOperatorType]*syncutil.Future[AppendOperator])
|
||||
|
||||
// AppendOperator is used to append messages, there's only two implement of this interface:
|
||||
// 1. streaming.WAL()
|
||||
// 2. old msgstream interface
|
||||
type AppendOperator interface {
|
||||
AppendMessages(ctx context.Context, msgs ...message.MutableMessage) types.AppendResponses
|
||||
}
|
||||
|
||||
func init() {
|
||||
localRegistry[AppendOperatorTypeMsgstream] = syncutil.NewFuture[AppendOperator]()
|
||||
localRegistry[AppendOperatorTypeStreaming] = syncutil.NewFuture[AppendOperator]()
|
||||
}
|
||||
|
||||
func Register(typ AppendOperatorType, op AppendOperator) {
|
||||
localRegistry[typ].Set(op)
|
||||
}
|
||||
|
||||
func GetAppendOperator() *syncutil.Future[AppendOperator] {
|
||||
if streamingutil.IsStreamingServiceEnabled() {
|
||||
return localRegistry[AppendOperatorTypeStreaming]
|
||||
}
|
||||
return localRegistry[AppendOperatorTypeMsgstream]
|
||||
}
|
||||
@ -3,12 +3,7 @@
|
||||
|
||||
package registry
|
||||
|
||||
import "github.com/milvus-io/milvus/pkg/v2/util/syncutil"
|
||||
|
||||
func ResetRegistration() {
|
||||
localRegistry = make(map[AppendOperatorType]*syncutil.Future[AppendOperator])
|
||||
localRegistry[AppendOperatorTypeMsgstream] = syncutil.NewFuture[AppendOperator]()
|
||||
localRegistry[AppendOperatorTypeStreaming] = syncutil.NewFuture[AppendOperator]()
|
||||
resetMessageAckCallbacks()
|
||||
resetMessageCheckCallbacks()
|
||||
}
|
||||
|
||||
@ -7,6 +7,7 @@ import (
|
||||
"github.com/cockroachdb/errors"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/distributed/streaming"
|
||||
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
|
||||
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
|
||||
@ -49,7 +50,7 @@ type pendingBroadcastTask struct {
|
||||
// Execute reexecute the task, return nil if the task is done, otherwise not done.
|
||||
// Execute can be repeated called until the task is done.
|
||||
// Same semantics as the `Poll` operation in eventloop.
|
||||
func (b *pendingBroadcastTask) Execute(ctx context.Context, operator AppendOperator) error {
|
||||
func (b *pendingBroadcastTask) Execute(ctx context.Context) error {
|
||||
if err := b.broadcastTask.InitializeRecovery(ctx); err != nil {
|
||||
b.Logger().Warn("broadcast task initialize recovery failed", zap.Error(err))
|
||||
b.UpdateInstantWithNextBackOff()
|
||||
@ -58,7 +59,7 @@ func (b *pendingBroadcastTask) Execute(ctx context.Context, operator AppendOpera
|
||||
|
||||
if len(b.pendingMessages) > 0 {
|
||||
b.Logger().Debug("broadcast task is polling to make sent...", zap.Int("pendingMessages", len(b.pendingMessages)))
|
||||
resps := operator.AppendMessages(ctx, b.pendingMessages...)
|
||||
resps := streaming.WAL().AppendMessages(ctx, b.pendingMessages...)
|
||||
newPendings := make([]message.MutableMessage, 0)
|
||||
for idx, resp := range resps.Responses {
|
||||
if resp.Error != nil {
|
||||
|
||||
@ -9,7 +9,6 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/streamingnode/client/manager"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/idalloc"
|
||||
"github.com/milvus-io/milvus/internal/util/streamingutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
@ -55,10 +54,8 @@ func Init(opts ...optResourceInit) {
|
||||
assertNotNil(newR.MixCoordClient())
|
||||
assertNotNil(newR.ETCD())
|
||||
assertNotNil(newR.StreamingCatalog())
|
||||
if streamingutil.IsStreamingServiceEnabled() {
|
||||
newR.streamingNodeManagerClient = manager.NewManagerClient(newR.etcdClient)
|
||||
assertNotNil(newR.StreamingNodeManagerClient())
|
||||
}
|
||||
newR.streamingNodeManagerClient = manager.NewManagerClient(newR.etcdClient)
|
||||
assertNotNil(newR.StreamingNodeManagerClient())
|
||||
r = newR
|
||||
}
|
||||
|
||||
|
||||
@ -10,11 +10,9 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
|
||||
_ "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/policy" // register the balancer policy
|
||||
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster"
|
||||
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
|
||||
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
|
||||
"github.com/milvus-io/milvus/internal/streamingcoord/server/service"
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
"github.com/milvus-io/milvus/internal/util/streamingutil"
|
||||
"github.com/milvus-io/milvus/internal/util/streamingutil/util"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
|
||||
@ -53,27 +51,25 @@ func (s *Server) Start(ctx context.Context) (err error) {
|
||||
// initBasicComponent initialize all underlying dependency for streamingcoord.
|
||||
func (s *Server) initBasicComponent(ctx context.Context) (err error) {
|
||||
futures := make([]*conc.Future[struct{}], 0)
|
||||
if streamingutil.IsStreamingServiceEnabled() {
|
||||
futures = append(futures, conc.Go(func() (struct{}, error) {
|
||||
s.logger.Info("start recovery balancer...")
|
||||
// Read new incoming topics from configuration, and register it into balancer.
|
||||
newIncomingTopics := util.GetAllTopicsFromConfiguration()
|
||||
balancer, err := balancer.RecoverBalancer(ctx, newIncomingTopics.Collect()...)
|
||||
if err != nil {
|
||||
s.logger.Warn("recover balancer failed", zap.Error(err))
|
||||
return struct{}{}, err
|
||||
}
|
||||
s.balancer.Set(balancer)
|
||||
snmanager.StaticStreamingNodeManager.SetBalancerReady(balancer)
|
||||
s.logger.Info("recover balancer done")
|
||||
return struct{}{}, nil
|
||||
}))
|
||||
}
|
||||
futures = append(futures, conc.Go(func() (struct{}, error) {
|
||||
s.logger.Info("start recovery balancer...")
|
||||
// Read new incoming topics from configuration, and register it into balancer.
|
||||
newIncomingTopics := util.GetAllTopicsFromConfiguration()
|
||||
balancer, err := balancer.RecoverBalancer(ctx, newIncomingTopics.Collect()...)
|
||||
if err != nil {
|
||||
s.logger.Warn("recover balancer failed", zap.Error(err))
|
||||
return struct{}{}, err
|
||||
}
|
||||
s.balancer.Set(balancer)
|
||||
snmanager.StaticStreamingNodeManager.SetBalancerReady(balancer)
|
||||
s.logger.Info("recover balancer done")
|
||||
return struct{}{}, nil
|
||||
}))
|
||||
// The broadcaster of msgstream is implemented on current streamingcoord to reduce the development complexity.
|
||||
// So we need to recover it.
|
||||
futures = append(futures, conc.Go(func() (struct{}, error) {
|
||||
s.logger.Info("start recovery broadcaster...")
|
||||
broadcaster, err := broadcaster.RecoverBroadcaster(ctx, registry.GetAppendOperator())
|
||||
broadcaster, err := broadcaster.RecoverBroadcaster(ctx)
|
||||
if err != nil {
|
||||
s.logger.Warn("recover broadcaster failed", zap.Error(err))
|
||||
return struct{}{}, err
|
||||
@ -87,9 +83,7 @@ func (s *Server) initBasicComponent(ctx context.Context) (err error) {
|
||||
|
||||
// RegisterGRPCService register all grpc service to grpc server.
|
||||
func (s *Server) RegisterGRPCService(grpcServer *grpc.Server) {
|
||||
if streamingutil.IsStreamingServiceEnabled() {
|
||||
streamingpb.RegisterStreamingCoordAssignmentServiceServer(grpcServer, s.assignmentService)
|
||||
}
|
||||
streamingpb.RegisterStreamingCoordAssignmentServiceServer(grpcServer, s.assignmentService)
|
||||
streamingpb.RegisterStreamingCoordBroadcastServiceServer(grpcServer, s.broadcastService)
|
||||
}
|
||||
|
||||
|
||||
@ -84,15 +84,6 @@ type DataNodeComponent interface {
|
||||
|
||||
// SetEtcdClient set etcd client for DataNode
|
||||
SetEtcdClient(etcdClient *clientv3.Client)
|
||||
|
||||
// SetMixCoordClient set SetMixCoordClient for DataNode
|
||||
// `mixCoord` is a client of root coordinator.
|
||||
//
|
||||
// Return a generic error in status:
|
||||
// If the mixCoord is nil or the mixCoord has been set before.
|
||||
// Return nil in status:
|
||||
// The mixCoord is not nil.
|
||||
SetMixCoordClient(mixCoord MixCoordClient) error
|
||||
}
|
||||
|
||||
// DataCoordClient is the client interface for datacoord server
|
||||
@ -196,9 +187,6 @@ type ProxyComponent interface {
|
||||
|
||||
SetAddress(address string)
|
||||
GetAddress() string
|
||||
// SetEtcdClient set EtcdClient for Proxy
|
||||
// `etcdClient` is a client of etcd
|
||||
SetEtcdClient(etcdClient *clientv3.Client)
|
||||
|
||||
// SetMixCoordClient set MixCoord for Proxy
|
||||
// `mixCoord` is a client of mix coordinator.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user