enhance: remove old non-streaming arch code (#43651)

issue: #41609

- remove all dead DML code in the proxy
- remove dead code in l0_write_buffer
- remove the msgstream dependency from the proxy
- remove the timetick reporter from the proxy
- remove the replicate stream implementation

---------

Signed-off-by: chyezh <chyezh@outlook.com>
Zhen Ye 2025-08-06 14:41:40 +08:00 committed by GitHub
parent 6ae727775f
commit 5551d99425
73 changed files with 611 additions and 4512 deletions
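For context, the proxy DML write path that replaces the removed msgstream producers appends messages to the WAL owned by the streaming service. Below is a minimal sketch of that path, based only on the accesser interface visible in this diff; the `streaming.WAL()` accessor name and the `appendDML` helper are illustrative assumptions, not part of this change.

```go
package example

import (
	"context"

	"github.com/milvus-io/milvus/internal/distributed/streaming"
	"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)

// appendDML appends already-built DML messages to the streaming WAL instead of
// producing them on a per-collection msgstream (the code removed in this commit).
// streaming.WAL() is assumed to return the package's WALAccesser singleton.
func appendDML(ctx context.Context, msgs ...message.MutableMessage) error {
	resp := streaming.WAL().AppendMessages(ctx, msgs...)
	for _, r := range resp.Responses {
		if r.Error != nil {
			return r.Error
		}
	}
	return nil
}
```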

View File

@@ -146,9 +146,7 @@ func GetMilvusRoles(args []string, flags *flag.FlagSet) *roles.MilvusRoles {
        role.EnableProxy = true
        role.EnableQueryNode = true
        role.EnableDataNode = true
-       if streamingutil.IsStreamingServiceEnabled() {
-           role.EnableStreamingNode = true
-       }
+       role.EnableStreamingNode = true
        role.Local = true
        role.Embedded = serverType == typeutil.EmbeddedRole
    case typeutil.MixCoordRole:

View File

@@ -12,6 +12,7 @@ packages:
      Utility:
      Broadcast:
      Local:
+     Scanner:
  github.com/milvus-io/milvus/internal/streamingcoord/server/balancer:
    interfaces:
      Balancer:

View File

@@ -1462,6 +1462,9 @@ func (s *Server) GetFlushState(ctx context.Context, req *datapb.GetFlushStateReq
    for _, sid := range req.GetSegmentIDs() {
        segment := s.meta.GetHealthySegment(ctx, sid)
        // segment is nil if it was compacted, or it's an empty segment and is set to dropped
+       // TODO: This is a dirty implementation: a growing segment may not be visible to mixcoord right away,
+       // it is only visible to streamingnode immediately, so the flush state should be checked at streamingnode, not here.
+       // In the future, use timetick for GetFlushState instead of the segment list.
        if segment == nil || isFlushState(segment.GetState()) {
            continue
        }

View File

@@ -32,13 +32,10 @@ import (
    "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
    "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
    dn "github.com/milvus-io/milvus/internal/datanode"
-   mix "github.com/milvus-io/milvus/internal/distributed/mixcoord/client"
    "github.com/milvus-io/milvus/internal/distributed/utils"
    "github.com/milvus-io/milvus/internal/types"
-   "github.com/milvus-io/milvus/internal/util/componentutil"
    "github.com/milvus-io/milvus/internal/util/dependency"
    _ "github.com/milvus-io/milvus/internal/util/grpcclient"
-   "github.com/milvus-io/milvus/internal/util/streamingutil"
    "github.com/milvus-io/milvus/pkg/v2/log"
    "github.com/milvus-io/milvus/pkg/v2/proto/datapb"
    "github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
@@ -65,8 +62,6 @@ type Server struct {
    factory  dependency.Factory
    serverID atomic.Int64
-   mixCoordClient func() (types.MixCoordClient, error)
}
// NewServer new DataNode grpc server
@@ -77,9 +72,6 @@ func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error)
        cancel:      cancel,
        factory:     factory,
        grpcErrChan: make(chan error),
-       mixCoordClient: func() (types.MixCoordClient, error) {
-           return mix.NewClient(ctx1)
-       },
    }
    s.serverID.Store(paramtable.GetNodeID())
@@ -173,10 +165,6 @@ func (s *Server) SetEtcdClient(client *clientv3.Client) {
    s.datanode.SetEtcdClient(client)
}
-func (s *Server) SetMixCoordInterface(ms types.MixCoordClient) error {
-   return s.datanode.SetMixCoordClient(ms)
-}
// Run initializes and starts Datanode's grpc service.
func (s *Server) Run() error {
    if err := s.init(); err != nil {
@@ -255,27 +243,6 @@ func (s *Server) init() error {
        return err
    }
if !streamingutil.IsStreamingServiceEnabled() {
// --- MixCoord Client ---
if s.mixCoordClient != nil {
log.Info("initializing MixCoord client for DataNode")
mixCoordClient, err := s.mixCoordClient()
if err != nil {
log.Error("failed to create new MixCoord client", zap.Error(err))
panic(err)
}
if err = componentutil.WaitForComponentHealthy(s.ctx, mixCoordClient, "MixCoord", 1000000, time.Millisecond*200); err != nil {
log.Error("failed to wait for MixCoord client to be ready", zap.Error(err))
panic(err)
}
log.Info("MixCoord client is ready for DataNode")
if err = s.SetMixCoordInterface(mixCoordClient); err != nil {
panic(err)
}
}
}
    s.datanode.UpdateStateCode(commonpb.StateCode_Initializing)
    if err := s.datanode.Init(); err != nil {

View File

@@ -27,7 +27,6 @@ import (
    "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
    "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
    "github.com/milvus-io/milvus/internal/mocks"
-   "github.com/milvus-io/milvus/internal/types"
    "github.com/milvus-io/milvus/pkg/v2/proto/datapb"
    "github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
    "github.com/milvus-io/milvus/pkg/v2/proto/workerpb"
@@ -43,27 +42,10 @@ func Test_NewServer(t *testing.T) {
    assert.NoError(t, err)
    assert.NotNil(t, server)
mockMixCoord := mocks.NewMockMixCoordClient(t)
mockMixCoord.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
StateCode: commonpb.StateCode_Healthy,
},
Status: merr.Success(),
SubcomponentStates: []*milvuspb.ComponentInfo{
{
StateCode: commonpb.StateCode_Healthy,
},
},
}, nil)
server.mixCoordClient = func() (types.MixCoordClient, error) {
return mockMixCoord, nil
}
    t.Run("Run", func(t *testing.T) {
        datanode := mocks.NewMockDataNode(t)
        datanode.EXPECT().SetEtcdClient(mock.Anything).Return()
        datanode.EXPECT().SetAddress(mock.Anything).Return()
-       datanode.EXPECT().SetMixCoordClient(mock.Anything).Return(nil)
        datanode.EXPECT().UpdateStateCode(mock.Anything).Return()
        datanode.EXPECT().Register().Return(nil)
        datanode.EXPECT().Init().Return(nil)
@@ -191,26 +173,9 @@ func Test_Run(t *testing.T) {
    assert.NoError(t, err)
    assert.NotNil(t, server)
mockRootCoord := mocks.NewMockMixCoordClient(t)
mockRootCoord.EXPECT().GetComponentStates(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
StateCode: commonpb.StateCode_Healthy,
},
Status: merr.Success(),
SubcomponentStates: []*milvuspb.ComponentInfo{
{
StateCode: commonpb.StateCode_Healthy,
},
},
}, nil)
server.mixCoordClient = func() (types.MixCoordClient, error) {
return mockRootCoord, nil
}
    datanode := mocks.NewMockDataNode(t)
    datanode.EXPECT().SetEtcdClient(mock.Anything).Return()
    datanode.EXPECT().SetAddress(mock.Anything).Return()
-   datanode.EXPECT().SetMixCoordClient(mock.Anything).Return(nil)
    datanode.EXPECT().UpdateStateCode(mock.Anything).Return()
    datanode.EXPECT().Init().Return(errors.New("mock err"))
    server.datanode = datanode
@@ -223,7 +188,6 @@
    datanode = mocks.NewMockDataNode(t)
    datanode.EXPECT().SetEtcdClient(mock.Anything).Return()
    datanode.EXPECT().SetAddress(mock.Anything).Return()
-   datanode.EXPECT().SetMixCoordClient(mock.Anything).Return(nil)
    datanode.EXPECT().UpdateStateCode(mock.Anything).Return()
    datanode.EXPECT().Register().Return(nil)
    datanode.EXPECT().Init().Return(nil)
@@ -242,26 +206,9 @@ func TestIndexService(t *testing.T) {
    assert.NoError(t, err)
    assert.NotNil(t, server)
mockRootCoord := mocks.NewMockMixCoordClient(t)
mockRootCoord.EXPECT().GetComponentStates(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
StateCode: commonpb.StateCode_Healthy,
},
Status: merr.Success(),
SubcomponentStates: []*milvuspb.ComponentInfo{
{
StateCode: commonpb.StateCode_Healthy,
},
},
}, nil)
server.mixCoordClient = func() (types.MixCoordClient, error) {
return mockRootCoord, nil
}
    dn := mocks.NewMockDataNode(t)
    dn.EXPECT().SetEtcdClient(mock.Anything).Return()
    dn.EXPECT().SetAddress(mock.Anything).Return()
-   dn.EXPECT().SetMixCoordClient(mock.Anything).Return(nil)
    dn.EXPECT().UpdateStateCode(mock.Anything).Return()
    dn.EXPECT().Register().Return(nil)
    dn.EXPECT().Init().Return(nil)

View File

@@ -449,7 +449,6 @@ func (s *Server) init() error {
        return err
    }
    s.etcdCli = etcdCli
-   s.proxy.SetEtcdClient(s.etcdCli)
    s.proxy.SetAddress(s.listenerManager.internalGrpcListener.Address())
    errChan := make(chan error, 1)

View File

@@ -201,7 +201,6 @@
    mockProxy.EXPECT().Init().Return(nil)
    mockProxy.EXPECT().Start().Return(nil)
    mockProxy.EXPECT().Register().Return(nil)
-   mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
    mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
    mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
    mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
@@ -687,7 +686,6 @@
    mockProxy.EXPECT().Init().Return(nil)
    mockProxy.EXPECT().Start().Return(nil)
    mockProxy.EXPECT().Register().Return(nil)
-   mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
    mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
    mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
    mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
@@ -821,7 +819,6 @@ func Test_NewServer_HTTPServer_Enabled(t *testing.T) {
    mockProxy.EXPECT().Init().Return(nil)
    mockProxy.EXPECT().Start().Return(nil)
    mockProxy.EXPECT().Register().Return(nil)
-   mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
    mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
    mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
    mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
@@ -886,7 +883,6 @@ func Test_NewServer_TLS_TwoWay(t *testing.T) {
    mockProxy.EXPECT().Init().Return(nil)
    mockProxy.EXPECT().Start().Return(nil)
    mockProxy.EXPECT().Register().Return(nil)
-   mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
    mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
    mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
    mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
@@ -914,7 +910,6 @@ func Test_NewServer_TLS_OneWay(t *testing.T) {
    mockProxy.EXPECT().Init().Return(nil)
    mockProxy.EXPECT().Start().Return(nil)
    mockProxy.EXPECT().Register().Return(nil)
-   mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
    mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
    mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
    mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
@@ -938,7 +933,6 @@ func Test_NewServer_TLS_FileNotExisted(t *testing.T) {
    mockProxy := server.proxy.(*mocks.MockProxy)
    mockProxy.EXPECT().Stop().Return(nil)
-   mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
    mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
    mockProxy.EXPECT().SetAddress(mock.Anything).Return()
@@ -977,7 +971,6 @@ func Test_NewHTTPServer_TLS_TwoWay(t *testing.T) {
    mockProxy.EXPECT().Init().Return(nil)
    mockProxy.EXPECT().Start().Return(nil)
    mockProxy.EXPECT().Register().Return(nil)
-   mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
    mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
    mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
    mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
@@ -1013,7 +1006,6 @@ func Test_NewHTTPServer_TLS_OneWay(t *testing.T) {
    mockProxy.EXPECT().Init().Return(nil)
    mockProxy.EXPECT().Start().Return(nil)
    mockProxy.EXPECT().Register().Return(nil)
-   mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
    mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
    mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
    mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
@@ -1046,7 +1038,6 @@ func Test_NewHTTPServer_TLS_FileNotExisted(t *testing.T) {
    mockProxy := server.proxy.(*mocks.MockProxy)
    mockProxy.EXPECT().Stop().Return(nil)
-   mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return().Maybe()
    mockProxy.EXPECT().SetAddress(mock.Anything).Return().Maybe()
    Params := &paramtable.Get().ProxyGrpcServerCfg
@@ -1160,7 +1151,6 @@ func Test_Service_GracefulStop(t *testing.T) {
    mockProxy.EXPECT().Start().Return(nil)
    mockProxy.EXPECT().Stop().Return(nil)
    mockProxy.EXPECT().Register().Return(nil)
-   mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
    mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
    mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
    mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()

View File

@@ -4,7 +4,6 @@ import (
    "context"
    "time"
-   "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
    kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
    "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
    "github.com/milvus-io/milvus/pkg/v2/streaming/util/options"
@@ -18,8 +17,6 @@ var singleton WALAccesser = nil
func Init() {
    c, _ := kvfactory.GetEtcdAndPath()
    singleton = newWALAccesser(c)
-   // Add the wal accesser to the broadcaster registry for making broadcast operation.
-   registry.Register(registry.AppendOperatorTypeStreaming, singleton)
}
// Release releases the resources of the wal accesser.

View File

@@ -18,7 +18,20 @@
package streaming
-import kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
+import (
+   "context"
+   "github.com/cockroachdb/errors"
+   "google.golang.org/protobuf/types/known/anypb"
+   kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
+   "github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
+   "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
+   "github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
+   "github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
+)
+var expectErr = make(chan error, 10)
// SetWALForTest initializes the singleton of wal for test.
func SetWALForTest(w WALAccesser) {
@@ -29,3 +42,146 @@ func RecoverWALForTest() {
    c, _ := kvfactory.GetEtcdAndPath()
    singleton = newWALAccesser(c)
}
func ExpectErrorOnce(err error) {
expectErr <- err
}
func SetupNoopWALForTest() {
singleton = &noopWALAccesser{}
}
type noopLocal struct{}
func (n *noopLocal) GetLatestMVCCTimestampIfLocal(ctx context.Context, vchannel string) (uint64, error) {
return 0, errors.New("not implemented")
}
func (n *noopLocal) GetMetricsIfLocal(ctx context.Context) (*types.StreamingNodeMetrics, error) {
return &types.StreamingNodeMetrics{}, nil
}
type noopBroadcast struct{}
func (n *noopBroadcast) Append(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) {
if err := getExpectErr(); err != nil {
return nil, err
}
return &types.BroadcastAppendResult{
BroadcastID: 1,
AppendResults: map[string]*types.AppendResult{
"v1": {
MessageID: rmq.NewRmqID(1),
TimeTick: 10,
Extra: &anypb.Any{},
},
},
}, nil
}
func (n *noopBroadcast) Ack(ctx context.Context, req types.BroadcastAckRequest) error {
return nil
}
type noopTxn struct{}
func (n *noopTxn) Append(ctx context.Context, msg message.MutableMessage, opts ...AppendOption) error {
if err := getExpectErr(); err != nil {
return err
}
return nil
}
func (n *noopTxn) Commit(ctx context.Context) (*types.AppendResult, error) {
if err := getExpectErr(); err != nil {
return nil, err
}
return &types.AppendResult{}, nil
}
func (n *noopTxn) Rollback(ctx context.Context) error {
if err := getExpectErr(); err != nil {
return err
}
return nil
}
type noopWALAccesser struct{}
func (n *noopWALAccesser) WALName() string {
return "noop"
}
func (n *noopWALAccesser) Local() Local {
return &noopLocal{}
}
func (n *noopWALAccesser) Txn(ctx context.Context, opts TxnOption) (Txn, error) {
if err := getExpectErr(); err != nil {
return nil, err
}
return &noopTxn{}, nil
}
func (n *noopWALAccesser) RawAppend(ctx context.Context, msgs message.MutableMessage, opts ...AppendOption) (*types.AppendResult, error) {
if err := getExpectErr(); err != nil {
return nil, err
}
extra, _ := anypb.New(&messagespb.ManualFlushExtraResponse{
SegmentIds: []int64{1, 2, 3},
})
return &types.AppendResult{
MessageID: rmq.NewRmqID(1),
TimeTick: 10,
Extra: extra,
}, nil
}
func (n *noopWALAccesser) Broadcast() Broadcast {
return &noopBroadcast{}
}
func (n *noopWALAccesser) Read(ctx context.Context, opts ReadOption) Scanner {
return &noopScanner{}
}
func (n *noopWALAccesser) AppendMessages(ctx context.Context, msgs ...message.MutableMessage) AppendResponses {
if err := getExpectErr(); err != nil {
return AppendResponses{
Responses: []AppendResponse{
{
AppendResult: nil,
Error: err,
},
},
}
}
return AppendResponses{}
}
func (n *noopWALAccesser) AppendMessagesWithOption(ctx context.Context, opts AppendOption, msgs ...message.MutableMessage) AppendResponses {
return AppendResponses{}
}
type noopScanner struct{}
func (n *noopScanner) Done() <-chan struct{} {
return make(chan struct{})
}
func (n *noopScanner) Error() error {
return nil
}
func (n *noopScanner) Close() {
}
// getExpectErr is a helper function to get the error from the expectErr channel.
func getExpectErr() error {
select {
case err := <-expectErr:
return err
default:
return nil
}
}
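For reference, a test in this package might use the noop WAL helpers above roughly as follows. This is only a sketch: the test name is hypothetical, and RecoverWALForTest reconnects through etcd, so it assumes the usual test environment; everything else relies solely on the helpers and types shown in this diff.

```go
package streaming

import (
	"context"
	"testing"

	"github.com/cockroachdb/errors"
	"github.com/stretchr/testify/assert"
)

func TestAppendWithNoopWAL(t *testing.T) {
	// Swap the package singleton for the in-memory noop WAL and restore the real one afterwards.
	SetupNoopWALForTest()
	defer RecoverWALForTest()

	// Queue an error so the next append on the noop WAL fails exactly once.
	ExpectErrorOnce(errors.New("mock append failure"))

	resp := singleton.AppendMessages(context.Background())
	assert.Error(t, resp.Responses[0].Error)

	// The queued error has been consumed, so subsequent appends succeed again.
	assert.Empty(t, singleton.AppendMessages(context.Background()).Responses)
}
```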

View File

@@ -11,7 +11,6 @@ import (
    "github.com/milvus-io/milvus/internal/distributed/streaming/internal/producer"
    "github.com/milvus-io/milvus/internal/streamingcoord/client"
    "github.com/milvus-io/milvus/internal/streamingnode/client/handler"
-   "github.com/milvus-io/milvus/internal/util/streamingutil"
    "github.com/milvus-io/milvus/internal/util/streamingutil/status"
    "github.com/milvus-io/milvus/internal/util/streamingutil/util"
    "github.com/milvus-io/milvus/pkg/v2/log"
@@ -29,11 +28,7 @@ func newWALAccesser(c *clientv3.Client) *walAccesserImpl {
    // Create a new streaming coord client.
    streamingCoordClient := client.NewClient(c)
    // Create a new streamingnode handler client.
-   var handlerClient handler.HandlerClient
-   if streamingutil.IsStreamingServiceEnabled() {
-       // streaming service is enabled, create the handler client for the streaming service.
-       handlerClient = handler.NewHandlerClient(streamingCoordClient.Assignment())
-   }
+   handlerClient := handler.NewHandlerClient(streamingCoordClient.Assignment())
    w := &walAccesserImpl{
        lifetime: typeutil.NewLifetime(),
        streamingCoordClient: streamingCoordClient,

View File

@@ -4,27 +4,21 @@ import (
    "context"
    "fmt"
-   "github.com/samber/lo"
    "go.uber.org/zap"
    "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
    "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
    "github.com/milvus-io/milvus/internal/allocator"
-   "github.com/milvus-io/milvus/internal/flushcommon/io"
    "github.com/milvus-io/milvus/internal/flushcommon/metacache"
    "github.com/milvus-io/milvus/internal/flushcommon/metacache/pkoracle"
    "github.com/milvus-io/milvus/internal/flushcommon/syncmgr"
    "github.com/milvus-io/milvus/internal/storage"
-   "github.com/milvus-io/milvus/internal/util/streamingutil"
-   "github.com/milvus-io/milvus/pkg/v2/common"
    "github.com/milvus-io/milvus/pkg/v2/metrics"
    "github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
    "github.com/milvus-io/milvus/pkg/v2/proto/datapb"
-   "github.com/milvus-io/milvus/pkg/v2/util/conc"
    "github.com/milvus-io/milvus/pkg/v2/util/merr"
    "github.com/milvus-io/milvus/pkg/v2/util/paramtable"
    "github.com/milvus-io/milvus/pkg/v2/util/retry"
-   "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
type l0WriteBuffer struct {
@@ -54,94 +48,6 @@ func NewL0WriteBuffer(channel string, metacache metacache.MetaCache, syncMgr syn
    }, nil
}
func (wb *l0WriteBuffer) dispatchDeleteMsgs(groups []*InsertData, deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) {
batchSize := paramtable.Get().CommonCfg.BloomFilterApplyBatchSize.GetAsInt()
split := func(pks []storage.PrimaryKey, pkTss []uint64, partitionSegments []*metacache.SegmentInfo, partitionGroups []*InsertData) []bool {
lc := storage.NewBatchLocationsCache(pks)
// use hits to cache result
hits := make([]bool, len(pks))
for _, segment := range partitionSegments {
hits = segment.GetBloomFilterSet().BatchPkExistWithHits(lc, hits)
}
for _, inData := range partitionGroups {
hits = inData.batchPkExists(pks, pkTss, hits)
}
return hits
}
type BatchApplyRet = struct {
// represent the idx for delete msg in deleteMsgs
DeleteDataIdx int
// represent the start idx for the batch in each deleteMsg
StartIdx int
Hits []bool
}
// transform pk to primary key
pksInDeleteMsgs := lo.Map(deleteMsgs, func(delMsg *msgstream.DeleteMsg, _ int) []storage.PrimaryKey {
return storage.ParseIDs2PrimaryKeys(delMsg.GetPrimaryKeys())
})
retIdx := 0
retMap := typeutil.NewConcurrentMap[int, *BatchApplyRet]()
pool := io.GetBFApplyPool()
var futures []*conc.Future[any]
for didx, delMsg := range deleteMsgs {
pks := pksInDeleteMsgs[didx]
pkTss := delMsg.GetTimestamps()
partitionSegments := wb.metaCache.GetSegmentsBy(metacache.WithPartitionID(delMsg.PartitionID),
metacache.WithSegmentState(commonpb.SegmentState_Growing, commonpb.SegmentState_Sealed, commonpb.SegmentState_Flushing, commonpb.SegmentState_Flushed))
partitionGroups := lo.Filter(groups, func(inData *InsertData, _ int) bool {
return delMsg.GetPartitionID() == common.AllPartitionsID || delMsg.GetPartitionID() == inData.partitionID
})
for idx := 0; idx < len(pks); idx += batchSize {
startIdx := idx
endIdx := idx + batchSize
if endIdx > len(pks) {
endIdx = len(pks)
}
retIdx += 1
tmpRetIdx := retIdx
deleteDataId := didx
future := pool.Submit(func() (any, error) {
hits := split(pks[startIdx:endIdx], pkTss[startIdx:endIdx], partitionSegments, partitionGroups)
retMap.Insert(tmpRetIdx, &BatchApplyRet{
DeleteDataIdx: deleteDataId,
StartIdx: startIdx,
Hits: hits,
})
return nil, nil
})
futures = append(futures, future)
}
}
conc.AwaitAll(futures...)
retMap.Range(func(key int, value *BatchApplyRet) bool {
l0SegmentID := wb.getL0SegmentID(deleteMsgs[value.DeleteDataIdx].GetPartitionID(), startPos)
pks := pksInDeleteMsgs[value.DeleteDataIdx]
pkTss := deleteMsgs[value.DeleteDataIdx].GetTimestamps()
var deletePks []storage.PrimaryKey
var deleteTss []typeutil.Timestamp
for i, hit := range value.Hits {
if hit {
deletePks = append(deletePks, pks[value.StartIdx+i])
deleteTss = append(deleteTss, pkTss[value.StartIdx+i])
}
}
if len(deletePks) > 0 {
wb.bufferDelete(l0SegmentID, deletePks, deleteTss, startPos, endPos)
}
return true
})
}
func (wb *l0WriteBuffer) dispatchDeleteMsgsWithoutFilter(deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) {
    for _, msg := range deleteMsgs {
        l0SegmentID := wb.getL0SegmentID(msg.GetPartitionID(), startPos)
@@ -165,31 +71,10 @@ func (wb *l0WriteBuffer) BufferData(insertData []*InsertData, deleteMsgs []*msgs
        }
    }
-   if paramtable.Get().DataNodeCfg.SkipBFStatsLoad.GetAsBool() || streamingutil.IsStreamingServiceEnabled() {
-       // In streaming service mode, flushed segments no longer maintain a bloom filter.
-       // So, here we skip generating BF (growing segment's BF will be regenerated during the sync phase)
-       // and also skip filtering delete entries by bf.
-       wb.dispatchDeleteMsgsWithoutFilter(deleteMsgs, startPos, endPos)
-   } else {
-       // distribute delete msg
-       // bf write buffer check bloom filter of segment and current insert batch to decide which segment to write delete data
-       wb.dispatchDeleteMsgs(insertData, deleteMsgs, startPos, endPos)
-       // update pk oracle
-       for _, inData := range insertData {
-           // segment shall always exists after buffer insert
-           segments := wb.metaCache.GetSegmentsBy(metacache.WithSegmentIDs(inData.segmentID))
-           for _, segment := range segments {
-               for _, fieldData := range inData.pkField {
-                   err := segment.GetBloomFilterSet().UpdatePKRange(fieldData)
-                   if err != nil {
-                       return err
-                   }
-               }
-           }
-       }
-   }
+   // In streaming service mode, flushed segments no longer maintain a bloom filter.
+   // So, here we skip generating BF (growing segment's BF will be regenerated during the sync phase)
+   // and also skip filtering delete entries by bf.
+   wb.dispatchDeleteMsgsWithoutFilter(deleteMsgs, startPos, endPos)
    // update buffer last checkpoint
    wb.checkpoint = endPos

View File

@@ -18,7 +18,6 @@ import (
    "github.com/milvus-io/milvus/internal/flushcommon/metacache/pkoracle"
    "github.com/milvus-io/milvus/internal/flushcommon/syncmgr"
    "github.com/milvus-io/milvus/internal/storage"
-   "github.com/milvus-io/milvus/internal/util/streamingutil"
    "github.com/milvus-io/milvus/pkg/v2/log"
    "github.com/milvus-io/milvus/pkg/v2/metrics"
    "github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
@@ -324,10 +323,8 @@ func (wb *writeBufferBase) syncSegments(ctx context.Context, segmentIDs []int64)
        }
        if syncTask.IsFlush() {
-           if paramtable.Get().DataNodeCfg.SkipBFStatsLoad.GetAsBool() || streamingutil.IsStreamingServiceEnabled() {
-               wb.metaCache.RemoveSegments(metacache.WithSegmentIDs(syncTask.SegmentID()))
-               log.Info("flushed segment removed", zap.Int64("segmentID", syncTask.SegmentID()), zap.String("channel", syncTask.ChannelName()))
-           }
+           wb.metaCache.RemoveSegments(metacache.WithSegmentIDs(syncTask.SegmentID()))
+           log.Info("flushed segment removed", zap.Int64("segmentID", syncTask.SegmentID()), zap.String("channel", syncTask.ChannelName()))
        }
        return nil
    })

View File

@@ -0,0 +1,156 @@
// Code generated by mockery v2.53.3. DO NOT EDIT.
package mock_streaming
import mock "github.com/stretchr/testify/mock"
// MockScanner is an autogenerated mock type for the Scanner type
type MockScanner struct {
mock.Mock
}
type MockScanner_Expecter struct {
mock *mock.Mock
}
func (_m *MockScanner) EXPECT() *MockScanner_Expecter {
return &MockScanner_Expecter{mock: &_m.Mock}
}
// Close provides a mock function with no fields
func (_m *MockScanner) Close() {
_m.Called()
}
// MockScanner_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
type MockScanner_Close_Call struct {
*mock.Call
}
// Close is a helper method to define mock.On call
func (_e *MockScanner_Expecter) Close() *MockScanner_Close_Call {
return &MockScanner_Close_Call{Call: _e.mock.On("Close")}
}
func (_c *MockScanner_Close_Call) Run(run func()) *MockScanner_Close_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockScanner_Close_Call) Return() *MockScanner_Close_Call {
_c.Call.Return()
return _c
}
func (_c *MockScanner_Close_Call) RunAndReturn(run func()) *MockScanner_Close_Call {
_c.Run(run)
return _c
}
// Done provides a mock function with no fields
func (_m *MockScanner) Done() <-chan struct{} {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Done")
}
var r0 <-chan struct{}
if rf, ok := ret.Get(0).(func() <-chan struct{}); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(<-chan struct{})
}
}
return r0
}
// MockScanner_Done_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Done'
type MockScanner_Done_Call struct {
*mock.Call
}
// Done is a helper method to define mock.On call
func (_e *MockScanner_Expecter) Done() *MockScanner_Done_Call {
return &MockScanner_Done_Call{Call: _e.mock.On("Done")}
}
func (_c *MockScanner_Done_Call) Run(run func()) *MockScanner_Done_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockScanner_Done_Call) Return(_a0 <-chan struct{}) *MockScanner_Done_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockScanner_Done_Call) RunAndReturn(run func() <-chan struct{}) *MockScanner_Done_Call {
_c.Call.Return(run)
return _c
}
// Error provides a mock function with no fields
func (_m *MockScanner) Error() error {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Error")
}
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
} else {
r0 = ret.Error(0)
}
return r0
}
// MockScanner_Error_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Error'
type MockScanner_Error_Call struct {
*mock.Call
}
// Error is a helper method to define mock.On call
func (_e *MockScanner_Expecter) Error() *MockScanner_Error_Call {
return &MockScanner_Error_Call{Call: _e.mock.On("Error")}
}
func (_c *MockScanner_Error_Call) Run(run func()) *MockScanner_Error_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockScanner_Error_Call) Return(_a0 error) *MockScanner_Error_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockScanner_Error_Call) RunAndReturn(run func() error) *MockScanner_Error_Call {
_c.Call.Return(run)
return _c
}
// NewMockScanner creates a new instance of MockScanner. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockScanner(t interface {
mock.TestingT
Cleanup(func())
}) *MockScanner {
mock := &MockScanner{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@@ -16,8 +16,6 @@ import (
    mock "github.com/stretchr/testify/mock"
-   types "github.com/milvus-io/milvus/internal/types"
    workerpb "github.com/milvus-io/milvus/pkg/v2/proto/workerpb"
)
@@ -1918,52 +1916,6 @@ func (_c *MockDataNode_SetEtcdClient_Call) RunAndReturn(run func(*clientv3.Clien
    return _c
}
// SetMixCoordClient provides a mock function with given fields: mixCoord
func (_m *MockDataNode) SetMixCoordClient(mixCoord types.MixCoordClient) error {
ret := _m.Called(mixCoord)
if len(ret) == 0 {
panic("no return value specified for SetMixCoordClient")
}
var r0 error
if rf, ok := ret.Get(0).(func(types.MixCoordClient) error); ok {
r0 = rf(mixCoord)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockDataNode_SetMixCoordClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetMixCoordClient'
type MockDataNode_SetMixCoordClient_Call struct {
*mock.Call
}
// SetMixCoordClient is a helper method to define mock.On call
// - mixCoord types.MixCoordClient
func (_e *MockDataNode_Expecter) SetMixCoordClient(mixCoord interface{}) *MockDataNode_SetMixCoordClient_Call {
return &MockDataNode_SetMixCoordClient_Call{Call: _e.mock.On("SetMixCoordClient", mixCoord)}
}
func (_c *MockDataNode_SetMixCoordClient_Call) Run(run func(mixCoord types.MixCoordClient)) *MockDataNode_SetMixCoordClient_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(types.MixCoordClient))
})
return _c
}
func (_c *MockDataNode_SetMixCoordClient_Call) Return(_a0 error) *MockDataNode_SetMixCoordClient_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockDataNode_SetMixCoordClient_Call) RunAndReturn(run func(types.MixCoordClient) error) *MockDataNode_SetMixCoordClient_Call {
_c.Call.Return(run)
return _c
}
// ShowConfigurations provides a mock function with given fields: _a0, _a1
func (_m *MockDataNode) ShowConfigurations(_a0 context.Context, _a1 *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) {
    ret := _m.Called(_a0, _a1)

View File

@@ -6,7 +6,6 @@ import (
    context "context"
    commonpb "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
-   clientv3 "go.etcd.io/etcd/client/v3"
    federpb "github.com/milvus-io/milvus-proto/go-api/v2/federpb"
@@ -6690,39 +6689,6 @@ func (_c *MockProxy_SetAddress_Call) RunAndReturn(run func(string)) *MockProxy_S
    return _c
}
// SetEtcdClient provides a mock function with given fields: etcdClient
func (_m *MockProxy) SetEtcdClient(etcdClient *clientv3.Client) {
_m.Called(etcdClient)
}
// MockProxy_SetEtcdClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetEtcdClient'
type MockProxy_SetEtcdClient_Call struct {
*mock.Call
}
// SetEtcdClient is a helper method to define mock.On call
// - etcdClient *clientv3.Client
func (_e *MockProxy_Expecter) SetEtcdClient(etcdClient interface{}) *MockProxy_SetEtcdClient_Call {
return &MockProxy_SetEtcdClient_Call{Call: _e.mock.On("SetEtcdClient", etcdClient)}
}
func (_c *MockProxy_SetEtcdClient_Call) Run(run func(etcdClient *clientv3.Client)) *MockProxy_SetEtcdClient_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(*clientv3.Client))
})
return _c
}
func (_c *MockProxy_SetEtcdClient_Call) Return() *MockProxy_SetEtcdClient_Call {
_c.Call.Return()
return _c
}
func (_c *MockProxy_SetEtcdClient_Call) RunAndReturn(run func(*clientv3.Client)) *MockProxy_SetEtcdClient_Call {
_c.Run(run)
return _c
}
// SetMixCoordClient provides a mock function with given fields: rootCoord
func (_m *MockProxy) SetMixCoordClient(rootCoord types.MixCoordClient) {
    _m.Called(rootCoord)

View File

@@ -39,9 +39,7 @@ import (
type channelsMgr interface {
    getChannels(collectionID UniqueID) ([]pChan, error)
    getVChannels(collectionID UniqueID) ([]vChan, error)
-   getOrCreateDmlStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error)
    removeDMLStream(collectionID UniqueID)
-   removeAllDMLStream()
}
type channelInfos struct {
@@ -52,7 +50,6 @@
type streamInfos struct {
    channelInfos channelInfos
-   stream msgstream.MsgStream
}
func removeDuplicate(ss []string) []string {
@@ -114,9 +111,8 @@
    infos map[UniqueID]streamInfos // collection id -> stream infos
    mu sync.RWMutex
    getChannelsFunc getChannelsFuncType
    repackFunc repackFuncType
-   msgStreamFactory msgstream.Factory
}
func (mgr *singleTypeChannelsMgr) getAllChannels(collectionID UniqueID) (channelInfos, error) {
@@ -167,27 +163,6 @@ func (mgr *singleTypeChannelsMgr) getVChannels(collectionID UniqueID) ([]vChan,
    return channelInfos.vchans, nil
}
func (mgr *singleTypeChannelsMgr) streamExistPrivate(collectionID UniqueID) bool {
streamInfos, ok := mgr.infos[collectionID]
return ok && streamInfos.stream != nil
}
func createStream(ctx context.Context, factory msgstream.Factory, pchans []pChan, repack repackFuncType) (msgstream.MsgStream, error) {
var stream msgstream.MsgStream
var err error
stream, err = factory.NewMsgStream(context.Background())
if err != nil {
return nil, err
}
stream.AsProducer(ctx, pchans)
if repack != nil {
stream.SetRepackFunc(repack)
}
return stream, nil
}
func incPChansMetrics(pchans []pChan) {
    for _, pc := range pchans {
        metrics.ProxyMsgStreamObjectsForPChan.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), pc).Inc()
@@ -200,67 +175,6 @@ func decPChanMetrics(pchans []pChan) {
    }
}
// createMsgStream create message stream for specified collection. Idempotent.
// If stream already exists, directly return it and no error will be returned.
func (mgr *singleTypeChannelsMgr) createMsgStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
mgr.mu.RLock()
infos, ok := mgr.infos[collectionID]
if ok && infos.stream != nil {
// already exist.
mgr.mu.RUnlock()
return infos.stream, nil
}
mgr.mu.RUnlock()
channelInfos, err := mgr.getChannelsFunc(collectionID)
if err != nil {
// What if stream created by other goroutines?
log.Error("failed to get channels", zap.Error(err), zap.Int64("collection", collectionID))
return nil, err
}
stream, err := createStream(ctx, mgr.msgStreamFactory, channelInfos.pchans, mgr.repackFunc)
if err != nil {
// What if stream created by other goroutines?
log.Error("failed to create message stream", zap.Error(err), zap.Int64("collection", collectionID))
return nil, err
}
mgr.mu.Lock()
defer mgr.mu.Unlock()
if !mgr.streamExistPrivate(collectionID) {
log.Info("create message stream", zap.Int64("collection", collectionID),
zap.Strings("virtual_channels", channelInfos.vchans),
zap.Strings("physical_channels", channelInfos.pchans))
mgr.infos[collectionID] = streamInfos{channelInfos: channelInfos, stream: stream}
incPChansMetrics(channelInfos.pchans)
} else {
stream.Close()
}
return mgr.infos[collectionID].stream, nil
}
func (mgr *singleTypeChannelsMgr) lockGetStream(collectionID UniqueID) (msgstream.MsgStream, error) {
mgr.mu.RLock()
defer mgr.mu.RUnlock()
streamInfos, ok := mgr.infos[collectionID]
if ok {
return streamInfos.stream, nil
}
return nil, fmt.Errorf("collection not found: %d", collectionID)
}
// getOrCreateStream get message stream of specified collection.
// If stream doesn't exist, call createMsgStream to create for it.
func (mgr *singleTypeChannelsMgr) getOrCreateStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
if stream, err := mgr.lockGetStream(collectionID); err == nil {
return stream, nil
}
return mgr.createMsgStream(ctx, collectionID)
}
// removeStream remove the corresponding stream of the specified collection. Idempotent.
// If stream already exists, remove it, otherwise do nothing.
func (mgr *singleTypeChannelsMgr) removeStream(collectionID UniqueID) {
@@ -268,34 +182,19 @@ func (mgr *singleTypeChannelsMgr) removeStream(collectionID UniqueID) {
    defer mgr.mu.Unlock()
    if info, ok := mgr.infos[collectionID]; ok {
        decPChanMetrics(info.channelInfos.pchans)
-       info.stream.Close()
        delete(mgr.infos, collectionID)
    }
    log.Info("dml stream removed", zap.Int64("collection_id", collectionID))
}
-// removeAllStream remove all message stream.
-func (mgr *singleTypeChannelsMgr) removeAllStream() {
-   mgr.mu.Lock()
-   defer mgr.mu.Unlock()
-   for _, info := range mgr.infos {
-       info.stream.Close()
-       decPChanMetrics(info.channelInfos.pchans)
-   }
-   mgr.infos = make(map[UniqueID]streamInfos)
-   log.Info("all dml stream removed")
-}
func newSingleTypeChannelsMgr(
    getChannelsFunc getChannelsFuncType,
-   msgStreamFactory msgstream.Factory,
    repackFunc repackFuncType,
) *singleTypeChannelsMgr {
    return &singleTypeChannelsMgr{
        infos: make(map[UniqueID]streamInfos),
        getChannelsFunc: getChannelsFunc,
        repackFunc: repackFunc,
-       msgStreamFactory: msgStreamFactory,
    }
}
@@ -315,25 +214,16 @@ func (mgr *channelsMgrImpl) getVChannels(collectionID UniqueID) ([]vChan, error)
    return mgr.dmlChannelsMgr.getVChannels(collectionID)
}
-func (mgr *channelsMgrImpl) getOrCreateDmlStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
-   return mgr.dmlChannelsMgr.getOrCreateStream(ctx, collectionID)
-}
func (mgr *channelsMgrImpl) removeDMLStream(collectionID UniqueID) {
    mgr.dmlChannelsMgr.removeStream(collectionID)
}
-func (mgr *channelsMgrImpl) removeAllDMLStream() {
-   mgr.dmlChannelsMgr.removeAllStream()
-}
// newChannelsMgrImpl constructs a channels manager.
func newChannelsMgrImpl(
    getDmlChannelsFunc getChannelsFuncType,
    dmlRepackFunc repackFuncType,
-   msgStreamFactory msgstream.Factory,
) *channelsMgrImpl {
    return &channelsMgrImpl{
-       dmlChannelsMgr: newSingleTypeChannelsMgr(getDmlChannelsFunc, msgStreamFactory, dmlRepackFunc),
+       dmlChannelsMgr: newSingleTypeChannelsMgr(getDmlChannelsFunc, dmlRepackFunc),
    }
}

View File

@@ -18,7 +18,6 @@ package proxy
import (
    "context"
-   "sync"
    "testing"
    "github.com/cockroachdb/errors"
@@ -28,8 +27,6 @@
    "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
    "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
    "github.com/milvus-io/milvus/internal/mocks"
-   "github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
-   "github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
func Test_removeDuplicate(t *testing.T) {
@@ -205,220 +202,11 @@ func Test_singleTypeChannelsMgr_getVChannels(t *testing.T) {
    })
}
func Test_createStream(t *testing.T) {
t.Run("failed to create msgstream", func(t *testing.T) {
factory := newMockMsgStreamFactory()
factory.fQStream = func(ctx context.Context) (msgstream.MsgStream, error) {
return nil, errors.New("mock")
}
_, err := createStream(context.TODO(), factory, nil, nil)
assert.Error(t, err)
})
t.Run("failed to create query msgstream", func(t *testing.T) {
factory := newMockMsgStreamFactory()
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return nil, errors.New("mock")
}
_, err := createStream(context.TODO(), factory, nil, nil)
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
factory := newMockMsgStreamFactory()
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return newMockMsgStream(), nil
}
_, err := createStream(context.TODO(), factory, []string{"111"}, func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
return nil, nil
})
assert.NoError(t, err)
})
}
func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
paramtable.Init()
t.Run("re-create", func(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{
100: {stream: newMockMsgStream()},
},
}
stream, err := m.createMsgStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
t.Run("concurrent create", func(t *testing.T) {
factory := newMockMsgStreamFactory()
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return newMockMsgStream(), nil
}
stopCh := make(chan struct{})
readyCh := make(chan struct{})
m := &singleTypeChannelsMgr{
infos: make(map[UniqueID]streamInfos),
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
close(readyCh)
<-stopCh
return channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}, nil
},
msgStreamFactory: factory,
repackFunc: nil,
}
firstStream := streamInfos{stream: newMockMsgStream()}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
stream, err := m.createMsgStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
}()
// make sure create msg stream has run at getchannels
<-readyCh
// mock create stream for same collection in same time.
m.mu.Lock()
m.infos[100] = firstStream
m.mu.Unlock()
close(stopCh)
wg.Wait()
})
t.Run("failed to get channels", func(t *testing.T) {
m := &singleTypeChannelsMgr{
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
return channelInfos{}, errors.New("mock")
},
}
_, err := m.createMsgStream(context.TODO(), 100)
assert.Error(t, err)
})
t.Run("failed to create message stream", func(t *testing.T) {
factory := newMockMsgStreamFactory()
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return nil, errors.New("mock")
}
m := &singleTypeChannelsMgr{
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
return channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}, nil
},
msgStreamFactory: factory,
repackFunc: nil,
}
_, err := m.createMsgStream(context.TODO(), 100)
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
factory := newMockMsgStreamFactory()
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return newMockMsgStream(), nil
}
m := &singleTypeChannelsMgr{
infos: make(map[UniqueID]streamInfos),
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
return channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}, nil
},
msgStreamFactory: factory,
repackFunc: nil,
}
stream, err := m.createMsgStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
stream, err = m.getOrCreateStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
}
func Test_singleTypeChannelsMgr_lockGetStream(t *testing.T) {
t.Run("collection not found", func(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: make(map[UniqueID]streamInfos),
}
_, err := m.lockGetStream(100)
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{
100: {stream: newMockMsgStream()},
},
}
stream, err := m.lockGetStream(100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
}
func Test_singleTypeChannelsMgr_getStream(t *testing.T) {
t.Run("exist", func(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{
100: {stream: newMockMsgStream()},
},
}
stream, err := m.getOrCreateStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
t.Run("failed to create", func(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{},
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
return channelInfos{}, errors.New("mock")
},
}
_, err := m.getOrCreateStream(context.TODO(), 100)
assert.Error(t, err)
})
t.Run("get after create", func(t *testing.T) {
factory := newMockMsgStreamFactory()
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return newMockMsgStream(), nil
}
m := &singleTypeChannelsMgr{
infos: make(map[UniqueID]streamInfos),
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
return channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}, nil
},
msgStreamFactory: factory,
repackFunc: nil,
}
stream, err := m.getOrCreateStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
}
func Test_singleTypeChannelsMgr_removeStream(t *testing.T) {
    m := &singleTypeChannelsMgr{
        infos: map[UniqueID]streamInfos{
-           100: {
-               stream: newMockMsgStream(),
-           },
+           100: {},
        },
    }
    m.removeStream(100)
-   _, err := m.lockGetStream(100)
-   assert.Error(t, err)
-}
-func Test_singleTypeChannelsMgr_removeAllStream(t *testing.T) {
-   m := &singleTypeChannelsMgr{
-       infos: map[UniqueID]streamInfos{
-           100: {
-               stream: newMockMsgStream(),
-           },
-       },
-   }
-   m.removeAllStream()
-   _, err := m.lockGetStream(100)
-   assert.Error(t, err)
}

View File

@@ -18,7 +18,6 @@ package proxy
import (
    "context"
-   "encoding/base64"
    "fmt"
    "os"
    "strconv"
@@ -46,7 +45,6 @@
    "github.com/milvus-io/milvus/internal/types"
    "github.com/milvus-io/milvus/internal/util/ctokenizer"
    "github.com/milvus-io/milvus/internal/util/hookutil"
-   "github.com/milvus-io/milvus/internal/util/streamingutil"
    "github.com/milvus-io/milvus/pkg/v2/common"
    "github.com/milvus-io/milvus/pkg/v2/log"
    "github.com/milvus-io/milvus/pkg/v2/metrics"
@@ -255,7 +253,6 @@ func (node *Proxy) CreateDatabase(ctx context.Context, request *milvuspb.CreateD
        Condition: NewTaskCondition(ctx),
        CreateDatabaseRequest: request,
        mixCoord: node.mixCoord,
-       replicateMsgStream: node.replicateMsgStream,
    }
    log := log.Ctx(ctx).With(
@@ -323,7 +320,6 @@ func (node *Proxy) DropDatabase(ctx context.Context, request *milvuspb.DropDatab
        Condition: NewTaskCondition(ctx),
        DropDatabaseRequest: request,
        mixCoord: node.mixCoord,
-       replicateMsgStream: node.replicateMsgStream,
    }
    log := log.Ctx(ctx).With(
@@ -452,7 +448,6 @@ func (node *Proxy) AlterDatabase(ctx context.Context, request *milvuspb.AlterDat
        Condition: NewTaskCondition(ctx),
        AlterDatabaseRequest: request,
        mixCoord: node.mixCoord,
-       replicateMsgStream: node.replicateMsgStream,
    }
    log := log.Ctx(ctx).With(
@@ -667,7 +662,6 @@ func (node *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCol
        DropCollectionRequest: request,
        mixCoord: node.mixCoord,
        chMgr: node.chMgr,
-       chTicker: node.chTicker,
    }
    log := log.Ctx(ctx).With(
@@ -833,7 +827,6 @@ func (node *Proxy) LoadCollection(ctx context.Context, request *milvuspb.LoadCol
        Condition: NewTaskCondition(ctx),
        LoadCollectionRequest: request,
        mixCoord: node.mixCoord,
-       replicateMsgStream: node.replicateMsgStream,
    }
    log := log.Ctx(ctx).With(
@@ -908,7 +901,6 @@ func (node *Proxy) ReleaseCollection(ctx context.Context, request *milvuspb.Rele
        Condition: NewTaskCondition(ctx),
        ReleaseCollectionRequest: request,
        mixCoord: node.mixCoord,
-       replicateMsgStream: node.replicateMsgStream,
    }
    log := log.Ctx(ctx).With(
@@ -1278,7 +1270,6 @@ func (node *Proxy) AlterCollection(ctx context.Context, request *milvuspb.AlterC
        Condition: NewTaskCondition(ctx),
        AlterCollectionRequest: request,
        mixCoord: node.mixCoord,
-       replicateMsgStream: node.replicateMsgStream,
    }
    log := log.Ctx(ctx).With(
@@ -1343,7 +1334,6 @@ func (node *Proxy) AlterCollectionField(ctx context.Context, request *milvuspb.A
        Condition: NewTaskCondition(ctx),
        AlterCollectionFieldRequest: request,
        mixCoord: node.mixCoord,
-       replicateMsgStream: node.replicateMsgStream,
    }
    log := log.Ctx(ctx).With(
@@ -1616,7 +1606,6 @@ func (node *Proxy) LoadPartitions(ctx context.Context, request *milvuspb.LoadPar
        Condition: NewTaskCondition(ctx),
        LoadPartitionsRequest: request,
        mixCoord: node.mixCoord,
-       replicateMsgStream: node.replicateMsgStream,
    }
    log := log.Ctx(ctx).With(
@@ -1682,7 +1671,6 @@ func (node *Proxy) ReleasePartitions(ctx context.Context, request *milvuspb.Rele
        Condition: NewTaskCondition(ctx),
        ReleasePartitionsRequest: request,
        mixCoord: node.mixCoord,
-       replicateMsgStream: node.replicateMsgStream,
    }
    method := "ReleasePartitions"
@@ -2082,11 +2070,6 @@ func (node *Proxy) CreateIndex(ctx context.Context, request *milvuspb.CreateInde
    defer sp.End()
    cit := &createIndexTask{
        ctx: ctx,
        Condition: NewTaskCondition(ctx),
        req: request,
        mixCoord: node.mixCoord,
-       replicateMsgStream: node.replicateMsgStream,
    }
    method := "CreateIndex"
@@ -2152,11 +2139,10 @@
    defer sp.End()
    task := &alterIndexTask{
ctx: ctx, ctx: ctx,
Condition: NewTaskCondition(ctx), Condition: NewTaskCondition(ctx),
req: request, req: request,
mixCoord: node.mixCoord, mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
} }
method := "AlterIndex" method := "AlterIndex"
@ -2370,11 +2356,10 @@ func (node *Proxy) DropIndex(ctx context.Context, request *milvuspb.DropIndexReq
defer sp.End() defer sp.End()
dit := &dropIndexTask{ dit := &dropIndexTask{
ctx: ctx, ctx: ctx,
Condition: NewTaskCondition(ctx), Condition: NewTaskCondition(ctx),
DropIndexRequest: request, DropIndexRequest: request,
mixCoord: node.mixCoord, mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
} }
method := "DropIndex" method := "DropIndex"
@ -2630,17 +2615,9 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
}, },
}, },
idAllocator: node.rowIDAllocator, idAllocator: node.rowIDAllocator,
segIDAssigner: node.segAssigner,
chMgr: node.chMgr, chMgr: node.chMgr,
chTicker: node.chTicker,
schemaTimestamp: request.SchemaTimestamp, schemaTimestamp: request.SchemaTimestamp,
} }
var enqueuedTask task = it
if streamingutil.IsStreamingServiceEnabled() {
enqueuedTask = &insertTaskByStreamingService{
insertTask: it,
}
}
constructFailedResponse := func(err error) *milvuspb.MutationResult { constructFailedResponse := func(err error) *milvuspb.MutationResult {
numRows := request.NumRows numRows := request.NumRows
@ -2657,7 +2634,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
log.Debug("Enqueue insert request in Proxy") log.Debug("Enqueue insert request in Proxy")
if err := node.sched.dmQueue.Enqueue(enqueuedTask); err != nil {
if err := node.sched.dmQueue.Enqueue(it); err != nil {
log.Warn("Failed to enqueue insert task: " + err.Error()) log.Warn("Failed to enqueue insert task: " + err.Error())
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
metrics.AbandonLabel, request.GetDbName(), request.GetCollectionName()).Inc() metrics.AbandonLabel, request.GetDbName(), request.GetCollectionName()).Inc()
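With the non-streaming DML path removed, the insert task built above is enqueued on the DML queue as-is; there is no insertTaskByStreamingService wrapper and no proxy-side segment ID assignment any more (the Delete and Upsert handlers below are trimmed the same way). A minimal sketch of that direct-enqueue shape, using hypothetical stand-in types rather than the real proxy ones:

package main

import (
	"errors"
	"fmt"
)

// task and dmTaskQueue are hypothetical stand-ins for the proxy's task and
// DML queue types; they only illustrate enqueueing the task directly.
type task interface{ Name() string }

type insertTask struct{ collection string }

func (t *insertTask) Name() string { return "insert/" + t.collection }

type dmTaskQueue struct{ pending []task }

func (q *dmTaskQueue) Enqueue(t task) error {
	if t == nil {
		return errors.New("nil task")
	}
	q.pending = append(q.pending, t)
	return nil
}

func main() {
	q := &dmTaskQueue{}
	it := &insertTask{collection: "demo"}
	// The task goes straight onto the queue; no streaming-mode wrapper is needed.
	if err := q.Enqueue(it); err != nil {
		fmt.Println("failed to enqueue insert task:", err)
		return
	}
	fmt.Println("enqueued:", it.Name())
}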
@ -2765,7 +2742,6 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
idAllocator: node.rowIDAllocator, idAllocator: node.rowIDAllocator,
tsoAllocatorIns: node.tsoAllocator, tsoAllocatorIns: node.tsoAllocator,
chMgr: node.chMgr, chMgr: node.chMgr,
chTicker: node.chTicker,
queue: node.sched.dmQueue, queue: node.sched.dmQueue,
lb: node.lbPolicy, lb: node.lbPolicy,
limiter: limiter, limiter: limiter,
@ -2874,23 +2850,15 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest)
}, },
idAllocator: node.rowIDAllocator, idAllocator: node.rowIDAllocator,
segIDAssigner: node.segAssigner,
chMgr: node.chMgr, chMgr: node.chMgr,
chTicker: node.chTicker,
schemaTimestamp: request.SchemaTimestamp, schemaTimestamp: request.SchemaTimestamp,
} }
var enqueuedTask task = it
if streamingutil.IsStreamingServiceEnabled() {
enqueuedTask = &upsertTaskByStreamingService{
upsertTask: it,
}
}
log.Debug("Enqueue upsert request in Proxy", log.Debug("Enqueue upsert request in Proxy",
zap.Int("len(FieldsData)", len(request.FieldsData)), zap.Int("len(FieldsData)", len(request.FieldsData)),
zap.Int("len(HashKeys)", len(request.HashKeys))) zap.Int("len(HashKeys)", len(request.HashKeys)))
if err := node.sched.dmQueue.Enqueue(enqueuedTask); err != nil {
if err := node.sched.dmQueue.Enqueue(it); err != nil {
log.Info("Failed to enqueue upsert task", log.Info("Failed to enqueue upsert task",
zap.Error(err)) zap.Error(err))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
@ -3561,11 +3529,11 @@ func (node *Proxy) Flush(ctx context.Context, request *milvuspb.FlushRequest) (*
defer sp.End() defer sp.End()
ft := &flushTask{ ft := &flushTask{
ctx: ctx, ctx: ctx,
Condition: NewTaskCondition(ctx), Condition: NewTaskCondition(ctx),
FlushRequest: request, FlushRequest: request,
mixCoord: node.mixCoord, mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
chMgr: node.chMgr,
} }
method := "Flush" method := "Flush"
@ -3578,16 +3546,7 @@ func (node *Proxy) Flush(ctx context.Context, request *milvuspb.FlushRequest) (*
zap.Any("collections", request.CollectionNames)) zap.Any("collections", request.CollectionNames))
log.Debug(rpcReceived(method)) log.Debug(rpcReceived(method))
if err := node.sched.dcQueue.Enqueue(ft); err != nil {
var enqueuedTask task = ft
if streamingutil.IsStreamingServiceEnabled() {
enqueuedTask = &flushTaskByStreamingService{
flushTask: ft,
chMgr: node.chMgr,
}
}
if err := node.sched.dcQueue.Enqueue(enqueuedTask); err != nil {
log.Warn( log.Warn(
rpcFailedToEnqueue(method), rpcFailedToEnqueue(method),
zap.Error(err)) zap.Error(err))
@ -3846,7 +3805,6 @@ func (node *Proxy) CreateAlias(ctx context.Context, request *milvuspb.CreateAlia
Condition: NewTaskCondition(ctx), Condition: NewTaskCondition(ctx),
CreateAliasRequest: request, CreateAliasRequest: request,
mixCoord: node.mixCoord, mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
} }
method := "CreateAlias" method := "CreateAlias"
@ -4034,11 +3992,10 @@ func (node *Proxy) DropAlias(ctx context.Context, request *milvuspb.DropAliasReq
defer sp.End() defer sp.End()
dat := &DropAliasTask{ dat := &DropAliasTask{
ctx: ctx, ctx: ctx,
Condition: NewTaskCondition(ctx), Condition: NewTaskCondition(ctx),
DropAliasRequest: request, DropAliasRequest: request,
mixCoord: node.mixCoord, mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
} }
method := "DropAlias" method := "DropAlias"
@ -4098,11 +4055,10 @@ func (node *Proxy) AlterAlias(ctx context.Context, request *milvuspb.AlterAliasR
defer sp.End() defer sp.End()
aat := &AlterAliasTask{ aat := &AlterAliasTask{
ctx: ctx, ctx: ctx,
Condition: NewTaskCondition(ctx), Condition: NewTaskCondition(ctx),
AlterAliasRequest: request, AlterAliasRequest: request,
mixCoord: node.mixCoord, mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
} }
method := "AlterAlias" method := "AlterAlias"
@ -4174,11 +4130,11 @@ func (node *Proxy) FlushAll(ctx context.Context, request *milvuspb.FlushAllReque
defer sp.End() defer sp.End()
ft := &flushAllTask{ ft := &flushAllTask{
ctx: ctx, ctx: ctx,
Condition: NewTaskCondition(ctx), Condition: NewTaskCondition(ctx),
FlushAllRequest: request, FlushAllRequest: request,
mixCoord: node.mixCoord, mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
chMgr: node.chMgr,
} }
method := "FlushAll" method := "FlushAll"
@ -4191,15 +4147,7 @@ func (node *Proxy) FlushAll(ctx context.Context, request *milvuspb.FlushAllReque
log.Debug(rpcReceived(method)) log.Debug(rpcReceived(method))
var enqueuedTask task = ft
if err := node.sched.dcQueue.Enqueue(ft); err != nil {
if streamingutil.IsStreamingServiceEnabled() {
enqueuedTask = &flushAllTaskbyStreamingService{
flushAllTask: ft,
chMgr: node.chMgr,
}
}
if err := node.sched.dcQueue.Enqueue(enqueuedTask); err != nil {
log.Warn(rpcFailedToEnqueue(method), zap.Error(err)) log.Warn(rpcFailedToEnqueue(method), zap.Error(err))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel, request.GetDbName(), "").Inc() metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel, request.GetDbName(), "").Inc()
resp.Status = merr.Status(err) resp.Status = merr.Status(err)
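In the Flush and FlushAll handlers above, the flushTaskByStreamingService and flushAllTaskbyStreamingService wrappers are gone; the channel manager is now carried on the task itself and the task is enqueued on the DDL queue directly. A rough sketch of that shape, with hypothetical stand-in types (the real flushTask, channelsMgr and dcQueue live in the proxy package):

package main

import "fmt"

// channelsMgr is a hypothetical stand-in for the proxy's channel manager.
type channelsMgr interface {
	getVChannels(collectionID int64) ([]string, error)
}

type fakeChMgr struct{}

func (fakeChMgr) getVChannels(collectionID int64) ([]string, error) {
	return []string{"vchan-1", "vchan-2"}, nil
}

// flushTask keeps a reference to the channel manager instead of relying
// on a streaming-service wrapper to supply it at enqueue time.
type flushTask struct {
	chMgr       channelsMgr
	collections []string
}

func (t *flushTask) Execute() error {
	for _, c := range t.collections {
		vchans, err := t.chMgr.getVChannels(0) // collection ID lookup elided
		if err != nil {
			return err
		}
		fmt.Printf("flush %s over %v\n", c, vchans)
	}
	return nil
}

func main() {
	ft := &flushTask{chMgr: fakeChMgr{}, collections: []string{"demo"}}
	if err := ft.Execute(); err != nil {
		fmt.Println("flush failed:", err)
	}
}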
@ -5198,9 +5146,6 @@ func (node *Proxy) CreateCredential(ctx context.Context, req *milvuspb.CreateCre
zap.Error(err)) zap.Error(err))
return merr.Status(err), nil return merr.Status(err), nil
} }
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, err return result, err
} }
@ -5273,9 +5218,6 @@ func (node *Proxy) UpdateCredential(ctx context.Context, req *milvuspb.UpdateCre
zap.Error(err)) zap.Error(err))
return merr.Status(err), nil return merr.Status(err), nil
} }
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, err return result, err
} }
@ -5306,9 +5248,6 @@ func (node *Proxy) DeleteCredential(ctx context.Context, req *milvuspb.DeleteCre
zap.Error(err)) zap.Error(err))
return merr.Status(err), nil return merr.Status(err), nil
} }
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, err return result, err
} }
@ -5372,9 +5311,6 @@ func (node *Proxy) CreateRole(ctx context.Context, req *milvuspb.CreateRoleReque
log.Warn("fail to create role", zap.Error(err)) log.Warn("fail to create role", zap.Error(err))
return merr.Status(err), nil return merr.Status(err), nil
} }
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, nil return result, nil
} }
@ -5407,9 +5343,6 @@ func (node *Proxy) DropRole(ctx context.Context, req *milvuspb.DropRoleRequest)
zap.Error(err)) zap.Error(err))
return merr.Status(err), nil return merr.Status(err), nil
} }
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, nil return result, nil
} }
@ -5439,9 +5372,6 @@ func (node *Proxy) OperateUserRole(ctx context.Context, req *milvuspb.OperateUse
log.Warn("fail to operate user role", zap.Error(err)) log.Warn("fail to operate user role", zap.Error(err))
return merr.Status(err), nil return merr.Status(err), nil
} }
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, nil return result, nil
} }
@ -5624,9 +5554,6 @@ func (node *Proxy) OperatePrivilegeV2(ctx context.Context, req *milvuspb.Operate
} }
} }
} }
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, nil return result, nil
} }
@ -5675,9 +5602,6 @@ func (node *Proxy) OperatePrivilege(ctx context.Context, req *milvuspb.OperatePr
} }
} }
} }
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, nil return result, nil
} }
@ -5914,11 +5838,6 @@ func (node *Proxy) RenameCollection(ctx context.Context, req *milvuspb.RenameCol
log.Warn("failed to rename collection", zap.Error(err)) log.Warn("failed to rename collection", zap.Error(err))
return merr.Status(err), err return merr.Status(err), err
} }
if merr.Ok(resp) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return resp, nil return resp, nil
} }
@ -6432,136 +6351,9 @@ func (node *Proxy) Connect(ctx context.Context, request *milvuspb.ConnectRequest
} }
func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.ReplicateMessageRequest) (*milvuspb.ReplicateMessageResponse, error) {
return &milvuspb.ReplicateMessageResponse{
Status: merr.Status(merr.WrapErrServiceUnavailable("not supported in streaming mode")),
}, nil
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
}
var err error
if req.GetChannelName() == "" {
log.Ctx(ctx).Warn("channel name is empty")
return &milvuspb.ReplicateMessageResponse{
Status: merr.Status(merr.WrapErrParameterInvalidMsg("invalid channel name for the replicate message request")),
}, nil
}
// get the latest position of the replicate msg channel
replicateMsgChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
if req.GetChannelName() == replicateMsgChannel {
msgID, err := msgstream.GetChannelLatestMsgID(ctx, node.factory, replicateMsgChannel)
if err != nil {
log.Ctx(ctx).Warn("failed to get the latest message id of the replicate msg channel", zap.Error(err))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
}
position := &msgpb.MsgPosition{
ChannelName: replicateMsgChannel,
MsgID: msgID,
}
positionBytes, err := proto.Marshal(position)
if err != nil {
log.Ctx(ctx).Warn("failed to marshal position", zap.Error(err))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
}
return &milvuspb.ReplicateMessageResponse{
Status: merr.Status(nil),
Position: base64.StdEncoding.EncodeToString(positionBytes),
}, nil
}
collectionReplicateEnable := paramtable.Get().CommonCfg.CollectionReplicateEnable.GetAsBool()
ttMsgEnabled := paramtable.Get().CommonCfg.TTMsgEnabled.GetAsBool()
// replicate message can be use in two ways, otherwise return error
// 1. collectionReplicateEnable is false and ttMsgEnabled is false, active/standby mode
// 2. collectionReplicateEnable is true and ttMsgEnabled is true, data migration mode
if (!collectionReplicateEnable && ttMsgEnabled) || (collectionReplicateEnable && !ttMsgEnabled) {
return &milvuspb.ReplicateMessageResponse{
Status: merr.Status(merr.ErrDenyReplicateMessage),
}, nil
}
msgPack := &msgstream.MsgPack{
BeginTs: req.BeginTs,
EndTs: req.EndTs,
Msgs: make([]msgstream.TsMsg, 0),
StartPositions: req.StartPositions,
EndPositions: req.EndPositions,
}
checkCollectionReplicateProperty := func(dbName, collectionName string) bool {
if !collectionReplicateEnable {
return true
}
replicateID, err := GetReplicateID(ctx, dbName, collectionName)
if err != nil {
log.Warn("get replicate id failed", zap.String("collectionName", collectionName), zap.Error(err))
return false
}
return replicateID != ""
}
// getTsMsgFromConsumerMsg
for i, msgBytes := range req.Msgs {
header := commonpb.MsgHeader{}
err = proto.Unmarshal(msgBytes, &header)
if err != nil {
log.Ctx(ctx).Warn("failed to unmarshal msg header", zap.Int("index", i), zap.Error(err))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
}
if header.GetBase() == nil {
log.Ctx(ctx).Warn("msg header base is nil", zap.Int("index", i))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.ErrInvalidMsgBytes)}, nil
}
tsMsg, err := node.replicateStreamManager.GetMsgDispatcher().Unmarshal(msgBytes, header.GetBase().GetMsgType())
if err != nil {
log.Ctx(ctx).Warn("failed to unmarshal msg", zap.Int("index", i), zap.Error(err))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.ErrInvalidMsgBytes)}, nil
}
switch realMsg := tsMsg.(type) {
case *msgstream.InsertMsg:
if !checkCollectionReplicateProperty(realMsg.GetDbName(), realMsg.GetCollectionName()) {
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.WrapErrCollectionReplicateMode("replicate"))}, nil
}
assignedSegmentInfos, err := node.segAssigner.GetSegmentID(realMsg.GetCollectionID(), realMsg.GetPartitionID(),
realMsg.GetShardName(), uint32(realMsg.NumRows), req.EndTs)
if err != nil {
log.Ctx(ctx).Warn("failed to get segment id", zap.Error(err))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
}
if len(assignedSegmentInfos) == 0 {
log.Ctx(ctx).Warn("no segment id assigned")
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.ErrNoAssignSegmentID)}, nil
}
for assignSegmentID := range assignedSegmentInfos {
realMsg.SegmentID = assignSegmentID
break
}
case *msgstream.DeleteMsg:
if !checkCollectionReplicateProperty(realMsg.GetDbName(), realMsg.GetCollectionName()) {
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.WrapErrCollectionReplicateMode("replicate"))}, nil
}
}
msgPack.Msgs = append(msgPack.Msgs, tsMsg)
}
msgStream, err := node.replicateStreamManager.GetReplicateMsgStream(ctx, req.ChannelName)
if err != nil {
log.Ctx(ctx).Warn("failed to get msg stream from the replicate stream manager", zap.Error(err))
return &milvuspb.ReplicateMessageResponse{
Status: merr.Status(err),
}, nil
}
messageIDsMap, err := msgStream.Broadcast(ctx, msgPack)
if err != nil {
log.Ctx(ctx).Warn("failed to produce msg", zap.Error(err))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
}
var position string
if len(messageIDsMap[req.GetChannelName()]) == 0 {
log.Ctx(ctx).Warn("no message id returned")
} else {
messageIDs := messageIDsMap[req.GetChannelName()]
position = base64.StdEncoding.EncodeToString(messageIDs[len(messageIDs)-1].Serialize())
}
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(nil), Position: position}, nil
} }
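After this change the proxy always answers ReplicateMessage with a service-unavailable status, so any remaining caller has to treat the RPC as unsupported rather than retrying it. A hypothetical caller-side check, sketched with stand-in types instead of the generated milvuspb/commonpb structs:

package main

import (
	"errors"
	"fmt"
)

// status and replicateMessageResponse are stand-ins for commonpb.Status and
// milvuspb.ReplicateMessageResponse; only the fields used here are modeled.
type status struct {
	Code   int32
	Reason string
}

type replicateMessageResponse struct {
	Status   *status
	Position string
}

// handleReplicateResponse shows how a caller might interpret the new
// behaviour: a non-zero status code means the RPC is not supported and the
// caller should fall back to another replication path instead of retrying.
func handleReplicateResponse(resp *replicateMessageResponse) (string, error) {
	if resp == nil || resp.Status == nil {
		return "", errors.New("empty response")
	}
	if resp.Status.Code != 0 {
		return "", fmt.Errorf("replicate message rejected: %s", resp.Status.Reason)
	}
	return resp.Position, nil
}

func main() {
	// Mirrors what the stubbed handler above now returns.
	resp := &replicateMessageResponse{
		Status: &status{Code: 1, Reason: "not supported in streaming mode"},
	}
	if _, err := handleReplicateResponse(resp); err != nil {
		fmt.Println(err)
	}
}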
func (node *Proxy) ListClientInfos(ctx context.Context, req *proxypb.ListClientInfosRequest) (*proxypb.ListClientInfosResponse, error) { func (node *Proxy) ListClientInfos(ctx context.Context, req *proxypb.ListClientInfosRequest) (*proxypb.ListClientInfosResponse, error) {
@ -6821,9 +6613,6 @@ func (node *Proxy) CreatePrivilegeGroup(ctx context.Context, req *milvuspb.Creat
log.Warn("fail to create privilege group", zap.Error(err)) log.Warn("fail to create privilege group", zap.Error(err))
return merr.Status(err), nil return merr.Status(err), nil
} }
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, nil return result, nil
} }
@ -6853,9 +6642,6 @@ func (node *Proxy) DropPrivilegeGroup(ctx context.Context, req *milvuspb.DropPri
log.Warn("fail to drop privilege group", zap.Error(err)) log.Warn("fail to drop privilege group", zap.Error(err))
return merr.Status(err), nil return merr.Status(err), nil
} }
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, nil return result, nil
} }
@ -6919,9 +6705,6 @@ func (node *Proxy) OperatePrivilegeGroup(ctx context.Context, req *milvuspb.Oper
log.Warn("fail to operate privilege group", zap.Error(err)) log.Warn("fail to operate privilege group", zap.Error(err))
return merr.Status(err), nil return merr.Status(err), nil
} }
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, nil return result, nil
} }

View File

@ -18,12 +18,10 @@ package proxy
import ( import (
"context" "context"
"encoding/base64"
"math/rand" "math/rand"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"time"
"github.com/bytedance/mockey" "github.com/bytedance/mockey"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
@ -31,38 +29,27 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/allocator"
grpcmixcoordclient "github.com/milvus-io/milvus/internal/distributed/mixcoord/client" grpcmixcoordclient "github.com/milvus-io/milvus/internal/distributed/mixcoord/client"
"github.com/milvus-io/milvus/internal/distributed/streaming"
mhttp "github.com/milvus-io/milvus/internal/http" mhttp "github.com/milvus-io/milvus/internal/http"
"github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/mocks/distributed/mock_streaming"
"github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
mqcommon "github.com/milvus-io/milvus/pkg/v2/mq/common"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb" "github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb" "github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/proxypb" "github.com/milvus-io/milvus/pkg/v2/proto/proxypb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb" "github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb" "github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/ratelimitutil" "github.com/milvus-io/milvus/pkg/v2/util/ratelimitutil"
"github.com/milvus-io/milvus/pkg/v2/util/resource"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
) )
@ -254,7 +241,7 @@ func TestProxy_ResourceGroup(t *testing.T) {
qc.EXPECT().ShowLoadCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() qc.EXPECT().ShowLoadCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
tsoAllocatorIns := newMockTsoAllocator() tsoAllocatorIns := newMockTsoAllocator()
node.sched, err = newTaskScheduler(node.ctx, tsoAllocatorIns, node.factory)
node.sched, err = newTaskScheduler(node.ctx, tsoAllocatorIns)
assert.NoError(t, err) assert.NoError(t, err)
node.sched.Start() node.sched.Start()
defer node.sched.Close() defer node.sched.Close()
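Throughout these tests, newTaskScheduler now takes only the context and the TSO allocator; the msgstream factory argument is gone. A minimal sketch of the new call shape, with hypothetical stand-in types for the scheduler and allocator:

package main

import (
	"context"
	"fmt"
)

// tsoAllocator and taskScheduler are hypothetical stand-ins; the point is
// only the constructor signature, which no longer needs a msgstream factory.
type tsoAllocator interface {
	AllocOne(ctx context.Context) (uint64, error)
}

type mockTso struct{ next uint64 }

func (m *mockTso) AllocOne(ctx context.Context) (uint64, error) {
	m.next++
	return m.next, nil
}

type taskScheduler struct {
	ctx context.Context
	tso tsoAllocator
}

func newTaskScheduler(ctx context.Context, tso tsoAllocator) (*taskScheduler, error) {
	return &taskScheduler{ctx: ctx, tso: tso}, nil
}

func main() {
	sched, err := newTaskScheduler(context.Background(), &mockTso{})
	if err != nil {
		fmt.Println("failed to build scheduler:", err)
		return
	}
	ts, _ := sched.tso.AllocOne(sched.ctx)
	fmt.Println("scheduler ready, first ts:", ts)
}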
@ -336,7 +323,7 @@ func TestProxy_InvalidResourceGroupName(t *testing.T) {
Status: merr.Success(), Status: merr.Success(),
}, nil).Maybe() }, nil).Maybe()
tsoAllocatorIns := newMockTsoAllocator() tsoAllocatorIns := newMockTsoAllocator()
node.sched, err = newTaskScheduler(node.ctx, tsoAllocatorIns, node.factory)
node.sched, err = newTaskScheduler(node.ctx, tsoAllocatorIns)
assert.NoError(t, err) assert.NoError(t, err)
node.sched.Start() node.sched.Start()
defer node.sched.Close() defer node.sched.Close()
@ -393,11 +380,7 @@ func createTestProxy() *Proxy {
tso: newMockTimestampAllocatorInterface(), tso: newMockTimestampAllocatorInterface(),
} }
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.sched, _ = newTaskScheduler(ctx, node.tsoAllocator)
node.replicateMsgStream, _ = node.factory.NewMsgStream(node.ctx)
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})
node.sched, _ = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
node.sched.Start() node.sched.Start()
return node return node
@ -418,10 +401,12 @@ func TestProxy_FlushAll_NoDatabase(t *testing.T) {
mockey.Mock(paramtable.Init).Return().Build() mockey.Mock(paramtable.Init).Return().Build()
mockey.Mock((*paramtable.ComponentParam).Save).Return().Build() mockey.Mock((*paramtable.ComponentParam).Save).Return().Build()
// Mock grpc mix coord client FlushAll method
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
mockey.Mock((*grpcmixcoordclient.Client).FlushAll).To(func(ctx context.Context, req *datapb.FlushAllRequest, opts ...grpc.CallOption) (*datapb.FlushAllResponse, error) {
return &datapb.FlushAllResponse{Status: successStatus}, nil
mockey.Mock((*grpcmixcoordclient.Client).ListDatabases).To(func(ctx context.Context, req *milvuspb.ListDatabasesRequest, opts ...grpc.CallOption) (*milvuspb.ListDatabasesResponse, error) {
return &milvuspb.ListDatabasesResponse{Status: successStatus}, nil
}).Build()
mockey.Mock((*grpcmixcoordclient.Client).ShowCollections).To(func(ctx context.Context, req *milvuspb.ShowCollectionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowCollectionsResponse, error) {
return &milvuspb.ShowCollectionsResponse{Status: successStatus}, nil
}).Build() }).Build()
// Act: Execute test // Act: Execute test
@ -454,10 +439,13 @@ func TestProxy_FlushAll_WithDefaultDatabase(t *testing.T) {
mockey.Mock(paramtable.Init).Return().Build() mockey.Mock(paramtable.Init).Return().Build()
mockey.Mock((*paramtable.ComponentParam).Save).Return().Build() mockey.Mock((*paramtable.ComponentParam).Save).Return().Build()
// Mock grpc mix coord client FlushAll method for default database
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
mockey.Mock((*grpcmixcoordclient.Client).FlushAll).To(func(ctx context.Context, req *datapb.FlushAllRequest, opts ...grpc.CallOption) (*datapb.FlushAllResponse, error) {
return &datapb.FlushAllResponse{Status: successStatus}, nil
mockey.Mock((*grpcmixcoordclient.Client).ListDatabases).To(func(ctx context.Context, req *milvuspb.ListDatabasesRequest, opts ...grpc.CallOption) (*milvuspb.ListDatabasesResponse, error) {
return &milvuspb.ListDatabasesResponse{Status: successStatus, DbNames: []string{"default"}}, nil
}).Build()
// Mock grpc mix coord client FlushAll method for default database
mockey.Mock((*grpcmixcoordclient.Client).ShowCollections).To(func(ctx context.Context, req *milvuspb.ShowCollectionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowCollectionsResponse, error) {
return &milvuspb.ShowCollectionsResponse{Status: successStatus}, nil
}).Build() }).Build()
// Act: Execute test // Act: Execute test
@ -490,9 +478,8 @@ func TestProxy_FlushAll_DatabaseNotExist(t *testing.T) {
mockey.Mock(paramtable.Init).Return().Build() mockey.Mock(paramtable.Init).Return().Build()
mockey.Mock((*paramtable.ComponentParam).Save).Return().Build() mockey.Mock((*paramtable.ComponentParam).Save).Return().Build()
// Mock grpc mix coord client FlushAll method for non-existent database
mockey.Mock((*grpcmixcoordclient.Client).FlushAll).To(func(ctx context.Context, req *datapb.FlushAllRequest, opts ...grpc.CallOption) (*datapb.FlushAllResponse, error) {
return &datapb.FlushAllResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_MetaFailed}}, nil
mockey.Mock((*grpcmixcoordclient.Client).ShowCollections).To(func(ctx context.Context, req *milvuspb.ShowCollectionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowCollectionsResponse, error) {
return &milvuspb.ShowCollectionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_MetaFailed}}, nil
}).Build() }).Build()
// Act: Execute test // Act: Execute test
@ -889,18 +876,13 @@ func TestProxyCreateDatabase(t *testing.T) {
} }
node.simpleLimiter = NewSimpleLimiter(0, 0) node.simpleLimiter = NewSimpleLimiter(0, 0)
node.UpdateStateCode(commonpb.StateCode_Healthy) node.UpdateStateCode(commonpb.StateCode_Healthy)
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator)
node.sched.ddQueue.setMaxTaskNum(10) node.sched.ddQueue.setMaxTaskNum(10)
assert.NoError(t, err) assert.NoError(t, err)
err = node.sched.Start() err = node.sched.Start()
assert.NoError(t, err) assert.NoError(t, err)
defer node.sched.Close() defer node.sched.Close()
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
assert.NoError(t, err)
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})
t.Run("create database fail", func(t *testing.T) { t.Run("create database fail", func(t *testing.T) {
mixc := mocks.NewMockMixCoordClient(t) mixc := mocks.NewMockMixCoordClient(t)
mixc.On("CreateDatabase", mock.Anything, mock.Anything). mixc.On("CreateDatabase", mock.Anything, mock.Anything).
@ -949,18 +931,13 @@ func TestProxyDropDatabase(t *testing.T) {
} }
node.simpleLimiter = NewSimpleLimiter(0, 0) node.simpleLimiter = NewSimpleLimiter(0, 0)
node.UpdateStateCode(commonpb.StateCode_Healthy) node.UpdateStateCode(commonpb.StateCode_Healthy)
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator)
node.sched.ddQueue.setMaxTaskNum(10) node.sched.ddQueue.setMaxTaskNum(10)
assert.NoError(t, err) assert.NoError(t, err)
err = node.sched.Start() err = node.sched.Start()
assert.NoError(t, err) assert.NoError(t, err)
defer node.sched.Close() defer node.sched.Close()
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
assert.NoError(t, err)
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})
t.Run("drop database fail", func(t *testing.T) { t.Run("drop database fail", func(t *testing.T) {
mixc := mocks.NewMockMixCoordClient(t) mixc := mocks.NewMockMixCoordClient(t)
mixc.On("DropDatabase", mock.Anything, mock.Anything). mixc.On("DropDatabase", mock.Anything, mock.Anything).
@ -1007,7 +984,7 @@ func TestProxyListDatabase(t *testing.T) {
} }
node.simpleLimiter = NewSimpleLimiter(0, 0) node.simpleLimiter = NewSimpleLimiter(0, 0)
node.UpdateStateCode(commonpb.StateCode_Healthy) node.UpdateStateCode(commonpb.StateCode_Healthy)
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator)
node.sched.ddQueue.setMaxTaskNum(10) node.sched.ddQueue.setMaxTaskNum(10)
assert.NoError(t, err) assert.NoError(t, err)
err = node.sched.Start() err = node.sched.Start()
@ -1063,7 +1040,7 @@ func TestProxyAlterDatabase(t *testing.T) {
} }
node.simpleLimiter = NewSimpleLimiter(0, 0) node.simpleLimiter = NewSimpleLimiter(0, 0)
node.UpdateStateCode(commonpb.StateCode_Healthy) node.UpdateStateCode(commonpb.StateCode_Healthy)
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator)
node.sched.ddQueue.setMaxTaskNum(10) node.sched.ddQueue.setMaxTaskNum(10)
assert.NoError(t, err) assert.NoError(t, err)
err = node.sched.Start() err = node.sched.Start()
@ -1116,7 +1093,7 @@ func TestProxyDescribeDatabase(t *testing.T) {
} }
node.simpleLimiter = NewSimpleLimiter(0, 0) node.simpleLimiter = NewSimpleLimiter(0, 0)
node.UpdateStateCode(commonpb.StateCode_Healthy) node.UpdateStateCode(commonpb.StateCode_Healthy)
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator)
node.sched.ddQueue.setMaxTaskNum(10) node.sched.ddQueue.setMaxTaskNum(10)
assert.NoError(t, err) assert.NoError(t, err)
err = node.sched.Start() err = node.sched.Start()
@ -1347,7 +1324,7 @@ func TestProxy_Delete(t *testing.T) {
idAllocator, err := allocator.NewIDAllocator(ctx, rc, 0) idAllocator, err := allocator.NewIDAllocator(ctx, rc, 0)
assert.NoError(t, err) assert.NoError(t, err)
queue, err := newTaskScheduler(ctx, tsoAllocator, nil)
queue, err := newTaskScheduler(ctx, tsoAllocator)
assert.NoError(t, err) assert.NoError(t, err)
node := &Proxy{chMgr: chMgr, rowIDAllocator: idAllocator, sched: queue} node := &Proxy{chMgr: chMgr, rowIDAllocator: idAllocator, sched: queue}
@ -1358,287 +1335,7 @@ func TestProxy_Delete(t *testing.T) {
}) })
} }
func TestProxy_ReplicateMessage(t *testing.T) {
paramtable.Init()
defer paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true")
t.Run("proxy unhealthy", func(t *testing.T) {
node := &Proxy{}
node.UpdateStateCode(commonpb.StateCode_Abnormal)
resp, err := node.ReplicateMessage(context.TODO(), nil)
assert.NoError(t, err)
assert.NotEqual(t, 0, resp.GetStatus().GetCode())
})
t.Run("not backup instance", func(t *testing.T) {
node := &Proxy{}
node.UpdateStateCode(commonpb.StateCode_Healthy)
resp, err := node.ReplicateMessage(context.TODO(), nil)
assert.NoError(t, err)
assert.NotEqual(t, 0, resp.GetStatus().GetCode())
})
t.Run("empty channel name", func(t *testing.T) {
node := &Proxy{}
node.UpdateStateCode(commonpb.StateCode_Healthy)
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false")
resp, err := node.ReplicateMessage(context.TODO(), nil)
assert.NoError(t, err)
assert.NotEqual(t, 0, resp.GetStatus().GetCode())
})
t.Run("fail to get msg stream", func(t *testing.T) {
factory := newMockMsgStreamFactory()
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return nil, errors.New("mock error: get msg stream")
}
resourceManager := resource.NewManager(time.Second, 2*time.Second, nil)
manager := NewReplicateStreamManager(context.Background(), factory, resourceManager)
node := &Proxy{
replicateStreamManager: manager,
}
node.UpdateStateCode(commonpb.StateCode_Healthy)
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false")
resp, err := node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{ChannelName: "unit_test_replicate_message"})
assert.NoError(t, err)
assert.NotEqual(t, 0, resp.GetStatus().GetCode())
})
t.Run("get latest position", func(t *testing.T) {
base64DecodeMsgPosition := func(position string) (*msgstream.MsgPosition, error) {
decodeBytes, err := base64.StdEncoding.DecodeString(position)
if err != nil {
log.Warn("fail to decode the position", zap.Error(err))
return nil, err
}
msgPosition := &msgstream.MsgPosition{}
err = proto.Unmarshal(decodeBytes, msgPosition)
if err != nil {
log.Warn("fail to unmarshal the position", zap.Error(err))
return nil, err
}
return msgPosition, nil
}
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false")
defer paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true")
factory := dependency.NewMockFactory(t)
stream := msgstream.NewMockMsgStream(t)
mockMsgID := mqcommon.NewMockMessageID(t)
factory.EXPECT().NewMsgStream(mock.Anything).Return(stream, nil).Once()
mockMsgID.EXPECT().Serialize().Return([]byte("mock")).Once()
stream.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once()
stream.EXPECT().GetLatestMsgID(mock.Anything).Return(mockMsgID, nil).Once()
stream.EXPECT().Close().Return()
node := &Proxy{
factory: factory,
}
node.UpdateStateCode(commonpb.StateCode_Healthy)
resp, err := node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{
ChannelName: Params.CommonCfg.ReplicateMsgChannel.GetValue(),
})
assert.NoError(t, err)
assert.EqualValues(t, 0, resp.GetStatus().GetCode())
{
p, err := base64DecodeMsgPosition(resp.GetPosition())
assert.NoError(t, err)
assert.Equal(t, []byte("mock"), p.MsgID)
}
factory.EXPECT().NewMsgStream(mock.Anything).Return(nil, errors.New("mock")).Once()
resp, err = node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{
ChannelName: Params.CommonCfg.ReplicateMsgChannel.GetValue(),
})
assert.NoError(t, err)
assert.NotEqualValues(t, 0, resp.GetStatus().GetCode())
})
t.Run("invalid msg pack", func(t *testing.T) {
node := &Proxy{
replicateStreamManager: NewReplicateStreamManager(context.Background(), nil, nil),
}
node.UpdateStateCode(commonpb.StateCode_Healthy)
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false")
{
resp, err := node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{
ChannelName: "unit_test_replicate_message",
Msgs: [][]byte{{1, 2, 3}},
})
assert.NoError(t, err)
assert.NotEqual(t, 0, resp.GetStatus().GetCode())
}
{
timeTickMsg := &msgstream.TimeTickMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: 1,
EndTimestamp: 10,
HashValues: []uint32{0},
},
TimeTickMsg: &msgpb.TimeTickMsg{},
}
msgBytes, _ := timeTickMsg.Marshal(timeTickMsg)
resp, err := node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{
ChannelName: "unit_test_replicate_message",
Msgs: [][]byte{msgBytes.([]byte)},
})
assert.NoError(t, err)
log.Info("resp", zap.Any("resp", resp))
assert.NotEqual(t, 0, resp.GetStatus().GetCode())
}
{
timeTickMsg := &msgstream.TimeTickMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: 1,
EndTimestamp: 10,
HashValues: []uint32{0},
},
TimeTickMsg: &msgpb.TimeTickMsg{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType(-1)),
commonpbutil.WithTimeStamp(10),
commonpbutil.WithSourceID(-1),
),
},
}
msgBytes, _ := timeTickMsg.Marshal(timeTickMsg)
resp, err := node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{
ChannelName: "unit_test_replicate_message",
Msgs: [][]byte{msgBytes.([]byte)},
})
assert.NoError(t, err)
log.Info("resp", zap.Any("resp", resp))
assert.NotEqual(t, 0, resp.GetStatus().GetCode())
}
})
t.Run("success", func(t *testing.T) {
paramtable.Init()
factory := newMockMsgStreamFactory()
msgStreamObj := msgstream.NewMockMsgStream(t)
msgStreamObj.EXPECT().SetRepackFunc(mock.Anything).Return()
msgStreamObj.EXPECT().AsProducer(mock.Anything, mock.Anything).Return()
msgStreamObj.EXPECT().ForceEnableProduce(mock.Anything).Return()
msgStreamObj.EXPECT().Close().Return()
mockMsgID1 := mqcommon.NewMockMessageID(t)
mockMsgID2 := mqcommon.NewMockMessageID(t)
mockMsgID2.EXPECT().Serialize().Return([]byte("mock message id 2"))
broadcastMock := msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{
"unit_test_replicate_message": {mockMsgID1, mockMsgID2},
}, nil)
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return msgStreamObj, nil
}
resourceManager := resource.NewManager(time.Second, 2*time.Second, nil)
manager := NewReplicateStreamManager(context.Background(), factory, resourceManager)
ctx := context.Background()
dataCoord := &mockDataCoord{}
dataCoord.expireTime = Timestamp(1000)
segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick1)
assert.NoError(t, err)
segAllocator.Start()
node := &Proxy{
replicateStreamManager: manager,
segAssigner: segAllocator,
}
node.UpdateStateCode(commonpb.StateCode_Healthy)
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false")
insertMsg := &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: 4,
EndTimestamp: 10,
HashValues: []uint32{0},
MsgPosition: &msgstream.MsgPosition{
ChannelName: "unit_test_replicate_message",
MsgID: []byte("mock message id 2"),
},
},
InsertRequest: &msgpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
MsgID: 10001,
Timestamp: 10,
SourceID: -1,
},
ShardName: "unit_test_replicate_message_v1",
DbName: "default",
CollectionName: "foo_collection",
PartitionName: "_default",
DbID: 1,
CollectionID: 11,
PartitionID: 22,
SegmentID: 33,
Timestamps: []uint64{10},
RowIDs: []int64{66},
NumRows: 1,
},
}
msgBytes, _ := insertMsg.Marshal(insertMsg)
replicateRequest := &milvuspb.ReplicateMessageRequest{
ChannelName: "unit_test_replicate_message",
BeginTs: 1,
EndTs: 10,
Msgs: [][]byte{msgBytes.([]byte)},
StartPositions: []*msgpb.MsgPosition{
{ChannelName: "unit_test_replicate_message", MsgID: []byte("mock message id 1")},
},
EndPositions: []*msgpb.MsgPosition{
{ChannelName: "unit_test_replicate_message", MsgID: []byte("mock message id 2")},
},
}
resp, err := node.ReplicateMessage(context.TODO(), replicateRequest)
assert.NoError(t, err)
assert.EqualValues(t, 0, resp.GetStatus().GetCode())
assert.Equal(t, base64.StdEncoding.EncodeToString([]byte("mock message id 2")), resp.GetPosition())
res := resourceManager.Delete(ReplicateMsgStreamTyp, replicateRequest.GetChannelName())
assert.NotNil(t, res)
time.Sleep(2 * time.Second)
{
broadcastMock.Unset()
broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(nil, errors.New("mock error: broadcast"))
resp, err := node.ReplicateMessage(context.TODO(), replicateRequest)
assert.NoError(t, err)
assert.NotEqualValues(t, 0, resp.GetStatus().GetCode())
resourceManager.Delete(ReplicateMsgStreamTyp, replicateRequest.GetChannelName())
time.Sleep(2 * time.Second)
}
{
broadcastMock.Unset()
broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{
"unit_test_replicate_message": {},
}, nil)
resp, err := node.ReplicateMessage(context.TODO(), replicateRequest)
assert.NoError(t, err)
assert.EqualValues(t, 0, resp.GetStatus().GetCode())
assert.Empty(t, resp.GetPosition())
resourceManager.Delete(ReplicateMsgStreamTyp, replicateRequest.GetChannelName())
time.Sleep(2 * time.Second)
broadcastMock.Unset()
}
})
}
func TestProxy_ImportV2(t *testing.T) { func TestProxy_ImportV2(t *testing.T) {
wal := mock_streaming.NewMockWALAccesser(t)
b := mock_streaming.NewMockBroadcast(t)
wal.EXPECT().Broadcast().Return(b).Maybe()
b.EXPECT().Append(mock.Anything, mock.Anything).Return(&types.BroadcastAppendResult{}, nil).Maybe()
streaming.SetWALForTest(wal)
defer streaming.RecoverWALForTest()
ctx := context.Background() ctx := context.Background()
mockErr := errors.New("mock error") mockErr := errors.New("mock error")
@ -1661,7 +1358,7 @@ func TestProxy_ImportV2(t *testing.T) {
node.tsoAllocator = &timestampAllocator{ node.tsoAllocator = &timestampAllocator{
tso: newMockTimestampAllocatorInterface(), tso: newMockTimestampAllocatorInterface(),
} }
scheduler, err := newTaskScheduler(ctx, node.tsoAllocator, factory)
scheduler, err := newTaskScheduler(ctx, node.tsoAllocator)
assert.NoError(t, err) assert.NoError(t, err)
node.sched = scheduler node.sched = scheduler
err = node.sched.Start() err = node.sched.Start()
@ -1916,333 +1613,6 @@ func TestRegisterRestRouter(t *testing.T) {
} }
} }
func TestReplicateMessageForCollectionMode(t *testing.T) {
paramtable.Init()
ctx := context.Background()
insertMsg := &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: 10,
EndTimestamp: 10,
HashValues: []uint32{0},
MsgPosition: &msgstream.MsgPosition{
ChannelName: "foo",
MsgID: []byte("mock message id 2"),
},
},
InsertRequest: &msgpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
MsgID: 10001,
Timestamp: 10,
SourceID: -1,
},
ShardName: "foo_v1",
DbName: "default",
CollectionName: "foo_collection",
PartitionName: "_default",
DbID: 1,
CollectionID: 11,
PartitionID: 22,
SegmentID: 33,
Timestamps: []uint64{10},
RowIDs: []int64{66},
NumRows: 1,
},
}
insertMsgBytes, _ := insertMsg.Marshal(insertMsg)
deleteMsg := &msgstream.DeleteMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: 20,
EndTimestamp: 20,
HashValues: []uint32{0},
MsgPosition: &msgstream.MsgPosition{
ChannelName: "foo",
MsgID: []byte("mock message id 2"),
},
},
DeleteRequest: &msgpb.DeleteRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Delete,
MsgID: 10002,
Timestamp: 20,
SourceID: -1,
},
ShardName: "foo_v1",
DbName: "default",
CollectionName: "foo_collection",
PartitionName: "_default",
DbID: 1,
CollectionID: 11,
PartitionID: 22,
},
}
deleteMsgBytes, _ := deleteMsg.Marshal(deleteMsg)
cache := globalMetaCache
defer func() { globalMetaCache = cache }()
t.Run("replicate message in the replicate collection mode", func(t *testing.T) {
defer func() {
paramtable.Get().Reset(paramtable.Get().CommonCfg.TTMsgEnabled.Key)
paramtable.Get().Reset(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key)
}()
{
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true")
paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "false")
p := &Proxy{}
p.UpdateStateCode(commonpb.StateCode_Healthy)
r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "foo",
})
assert.NoError(t, err)
assert.Error(t, merr.Error(r.Status))
}
{
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false")
paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "true")
p := &Proxy{}
p.UpdateStateCode(commonpb.StateCode_Healthy)
r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "foo",
})
assert.NoError(t, err)
assert.Error(t, merr.Error(r.Status))
}
})
t.Run("replicate message for the replicate collection mode", func(t *testing.T) {
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true")
paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "true")
defer func() {
paramtable.Get().Reset(paramtable.Get().CommonCfg.TTMsgEnabled.Key)
paramtable.Get().Reset(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key)
}()
mockCache := NewMockCache(t)
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil).Twice()
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{}, nil).Twice()
globalMetaCache = mockCache
{
p := &Proxy{
replicateStreamManager: NewReplicateStreamManager(context.Background(), nil, nil),
}
p.UpdateStateCode(commonpb.StateCode_Healthy)
r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "foo",
Msgs: [][]byte{insertMsgBytes.([]byte)},
})
assert.NoError(t, err)
assert.EqualValues(t, r.GetStatus().GetCode(), merr.Code(merr.ErrCollectionReplicateMode))
}
{
p := &Proxy{
replicateStreamManager: NewReplicateStreamManager(context.Background(), nil, nil),
}
p.UpdateStateCode(commonpb.StateCode_Healthy)
r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "foo",
Msgs: [][]byte{deleteMsgBytes.([]byte)},
})
assert.NoError(t, err)
assert.EqualValues(t, r.GetStatus().GetCode(), merr.Code(merr.ErrCollectionReplicateMode))
}
})
}
func TestAlterCollectionReplicateProperty(t *testing.T) {
paramtable.Init()
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true")
paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "true")
defer func() {
paramtable.Get().Reset(paramtable.Get().CommonCfg.TTMsgEnabled.Key)
paramtable.Get().Reset(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key)
}()
cache := globalMetaCache
defer func() { globalMetaCache = cache }()
mockCache := NewMockCache(t)
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{
replicateID: "local-milvus",
}, nil).Maybe()
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(1, nil).Maybe()
mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{}, nil)
globalMetaCache = mockCache
factory := newMockMsgStreamFactory()
msgStreamObj := msgstream.NewMockMsgStream(t)
msgStreamObj.EXPECT().SetRepackFunc(mock.Anything).Return().Maybe()
msgStreamObj.EXPECT().AsProducer(mock.Anything, mock.Anything).Return().Maybe()
msgStreamObj.EXPECT().ForceEnableProduce(mock.Anything).Return().Maybe()
msgStreamObj.EXPECT().Close().Return().Maybe()
mockMsgID1 := mqcommon.NewMockMessageID(t)
mockMsgID2 := mqcommon.NewMockMessageID(t)
mockMsgID2.EXPECT().Serialize().Return([]byte("mock message id 2")).Maybe()
msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{
"alter_property": {mockMsgID1, mockMsgID2},
}, nil).Maybe()
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return msgStreamObj, nil
}
resourceManager := resource.NewManager(time.Second, 2*time.Second, nil)
manager := NewReplicateStreamManager(context.Background(), factory, resourceManager)
ctx := context.Background()
var startTt uint64 = 10
startTime := time.Now()
dataCoord := &mockDataCoord{}
dataCoord.expireTime = Timestamp(1000)
segAllocator, err := newSegIDAssigner(ctx, dataCoord, func() Timestamp {
return Timestamp(time.Since(startTime).Seconds()) + startTt
})
assert.NoError(t, err)
segAllocator.Start()
mockMixcoord := mocks.NewMockMixCoordClient(t)
mockMixcoord.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request *rootcoordpb.AllocTimestampRequest, option ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) {
return &rootcoordpb.AllocTimestampResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Timestamp: Timestamp(time.Since(startTime).Seconds()) + startTt,
}, nil
})
mockMixcoord.EXPECT().AlterCollection(mock.Anything, mock.Anything).Return(&commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}, nil)
p := &Proxy{
ctx: ctx,
replicateStreamManager: manager,
segAssigner: segAllocator,
mixCoord: mockMixcoord,
}
tsoAllocatorIns := newMockTsoAllocator()
p.sched, err = newTaskScheduler(p.ctx, tsoAllocatorIns, p.factory)
assert.NoError(t, err)
p.sched.Start()
defer p.sched.Close()
p.UpdateStateCode(commonpb.StateCode_Healthy)
getInsertMsgBytes := func(channel string, ts uint64) []byte {
insertMsg := &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: ts,
EndTimestamp: ts,
HashValues: []uint32{0},
MsgPosition: &msgstream.MsgPosition{
ChannelName: channel,
MsgID: []byte("mock message id 2"),
},
},
InsertRequest: &msgpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
MsgID: 10001,
Timestamp: ts,
SourceID: -1,
},
ShardName: channel + "_v1",
DbName: "default",
CollectionName: "foo_collection",
PartitionName: "_default",
DbID: 1,
CollectionID: 11,
PartitionID: 22,
SegmentID: 33,
Timestamps: []uint64{ts},
RowIDs: []int64{66},
NumRows: 1,
},
}
insertMsgBytes, _ := insertMsg.Marshal(insertMsg)
return insertMsgBytes.([]byte)
}
getDeleteMsgBytes := func(channel string, ts uint64) []byte {
deleteMsg := &msgstream.DeleteMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: ts,
EndTimestamp: ts,
HashValues: []uint32{0},
MsgPosition: &msgstream.MsgPosition{
ChannelName: "foo",
MsgID: []byte("mock message id 2"),
},
},
DeleteRequest: &msgpb.DeleteRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Delete,
MsgID: 10002,
Timestamp: ts,
SourceID: -1,
},
ShardName: channel + "_v1",
DbName: "default",
CollectionName: "foo_collection",
PartitionName: "_default",
DbID: 1,
CollectionID: 11,
PartitionID: 22,
},
}
deleteMsgBytes, _ := deleteMsg.Marshal(deleteMsg)
return deleteMsgBytes.([]byte)
}
go func() {
// replicate message
var replicateResp *milvuspb.ReplicateMessageResponse
var err error
replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "alter_property_1",
Msgs: [][]byte{getInsertMsgBytes("alter_property_1", startTt+5)},
})
assert.NoError(t, err)
assert.True(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason)
replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "alter_property_2",
Msgs: [][]byte{getDeleteMsgBytes("alter_property_2", startTt+5)},
})
assert.NoError(t, err)
assert.True(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason)
time.Sleep(time.Second)
replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "alter_property_1",
Msgs: [][]byte{getInsertMsgBytes("alter_property_1", startTt+10)},
})
assert.NoError(t, err)
assert.False(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason)
replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "alter_property_2",
Msgs: [][]byte{getInsertMsgBytes("alter_property_2", startTt+10)},
})
assert.NoError(t, err)
assert.False(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason)
}()
time.Sleep(200 * time.Millisecond)
// alter collection property
statusResp, err := p.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
DbName: "default",
CollectionName: "foo_collection",
Properties: []*commonpb.KeyValuePair{
{
Key: "replicate.endTS",
Value: "1",
},
},
})
assert.NoError(t, err)
assert.True(t, merr.Ok(statusResp))
}
func TestRunAnalyzer(t *testing.T) { func TestRunAnalyzer(t *testing.T) {
paramtable.Init() paramtable.Init()
ctx := context.Background() ctx := context.Background()
@ -2254,7 +1624,7 @@ func TestRunAnalyzer(t *testing.T) {
p := &Proxy{} p := &Proxy{}
tsoAllocatorIns := newMockTsoAllocator() tsoAllocatorIns := newMockTsoAllocator()
sched, err := newTaskScheduler(ctx, tsoAllocatorIns, p.factory)
sched, err := newTaskScheduler(ctx, tsoAllocatorIns)
require.NoError(t, err) require.NoError(t, err)
sched.Start() sched.Start()
defer sched.Close() defer sched.Close()

View File

@ -2,12 +2,7 @@
package proxy package proxy
import (
import mock "github.com/stretchr/testify/mock"
context "context"
msgstream "github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
mock "github.com/stretchr/testify/mock"
)
// MockChannelsMgr is an autogenerated mock type for the channelsMgr type // MockChannelsMgr is an autogenerated mock type for the channelsMgr type
type MockChannelsMgr struct { type MockChannelsMgr struct {
@ -80,65 +75,6 @@ func (_c *MockChannelsMgr_getChannels_Call) RunAndReturn(run func(int64) ([]stri
return _c return _c
} }
// getOrCreateDmlStream provides a mock function with given fields: ctx, collectionID
func (_m *MockChannelsMgr) getOrCreateDmlStream(ctx context.Context, collectionID int64) (msgstream.MsgStream, error) {
ret := _m.Called(ctx, collectionID)
if len(ret) == 0 {
panic("no return value specified for getOrCreateDmlStream")
}
var r0 msgstream.MsgStream
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, int64) (msgstream.MsgStream, error)); ok {
return rf(ctx, collectionID)
}
if rf, ok := ret.Get(0).(func(context.Context, int64) msgstream.MsgStream); ok {
r0 = rf(ctx, collectionID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(msgstream.MsgStream)
}
}
if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok {
r1 = rf(ctx, collectionID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockChannelsMgr_getOrCreateDmlStream_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'getOrCreateDmlStream'
type MockChannelsMgr_getOrCreateDmlStream_Call struct {
*mock.Call
}
// getOrCreateDmlStream is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
func (_e *MockChannelsMgr_Expecter) getOrCreateDmlStream(ctx interface{}, collectionID interface{}) *MockChannelsMgr_getOrCreateDmlStream_Call {
return &MockChannelsMgr_getOrCreateDmlStream_Call{Call: _e.mock.On("getOrCreateDmlStream", ctx, collectionID)}
}
func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) Run(run func(ctx context.Context, collectionID int64)) *MockChannelsMgr_getOrCreateDmlStream_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64))
})
return _c
}
func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) Return(_a0 msgstream.MsgStream, _a1 error) *MockChannelsMgr_getOrCreateDmlStream_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) RunAndReturn(run func(context.Context, int64) (msgstream.MsgStream, error)) *MockChannelsMgr_getOrCreateDmlStream_Call {
_c.Call.Return(run)
return _c
}
// getVChannels provides a mock function with given fields: collectionID
func (_m *MockChannelsMgr) getVChannels(collectionID int64) ([]string, error) {
ret := _m.Called(collectionID)
@ -197,38 +133,6 @@ func (_c *MockChannelsMgr_getVChannels_Call) RunAndReturn(run func(int64) ([]str
return _c
}
// removeAllDMLStream provides a mock function with no fields
func (_m *MockChannelsMgr) removeAllDMLStream() {
_m.Called()
}
// MockChannelsMgr_removeAllDMLStream_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'removeAllDMLStream'
type MockChannelsMgr_removeAllDMLStream_Call struct {
*mock.Call
}
// removeAllDMLStream is a helper method to define mock.On call
func (_e *MockChannelsMgr_Expecter) removeAllDMLStream() *MockChannelsMgr_removeAllDMLStream_Call {
return &MockChannelsMgr_removeAllDMLStream_Call{Call: _e.mock.On("removeAllDMLStream")}
}
func (_c *MockChannelsMgr_removeAllDMLStream_Call) Run(run func()) *MockChannelsMgr_removeAllDMLStream_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockChannelsMgr_removeAllDMLStream_Call) Return() *MockChannelsMgr_removeAllDMLStream_Call {
_c.Call.Return()
return _c
}
func (_c *MockChannelsMgr_removeAllDMLStream_Call) RunAndReturn(run func()) *MockChannelsMgr_removeAllDMLStream_Call {
_c.Run(run)
return _c
}
// removeDMLStream provides a mock function with given fields: collectionID
func (_m *MockChannelsMgr) removeDMLStream(collectionID int64) {
_m.Called(collectionID)

View File

@ -18,23 +18,12 @@ package proxy
import (
"context"
"strconv"
"time"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/retry"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -103,197 +92,3 @@ func genInsertMsgsByPartition(ctx context.Context,
return repackedMsgs, nil
}
func repackInsertDataByPartition(ctx context.Context,
partitionName string,
rowOffsets []int,
channelName string,
insertMsg *msgstream.InsertMsg,
segIDAssigner *segIDAssigner,
) ([]msgstream.TsMsg, error) {
res := make([]msgstream.TsMsg, 0)
maxTs := Timestamp(0)
for _, offset := range rowOffsets {
ts := insertMsg.Timestamps[offset]
if maxTs < ts {
maxTs = ts
}
}
partitionID, err := globalMetaCache.GetPartitionID(ctx, insertMsg.GetDbName(), insertMsg.CollectionName, partitionName)
if err != nil {
return nil, err
}
beforeAssign := time.Now()
assignedSegmentInfos, err := segIDAssigner.GetSegmentID(insertMsg.CollectionID, partitionID, channelName, uint32(len(rowOffsets)), maxTs)
metrics.ProxyAssignSegmentIDLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(time.Since(beforeAssign).Milliseconds()))
if err != nil {
log.Ctx(ctx).Error("allocate segmentID for insert data failed",
zap.String("collectionName", insertMsg.CollectionName),
zap.String("channelName", channelName),
zap.Int("allocate count", len(rowOffsets)),
zap.Error(err))
return nil, err
}
startPos := 0
for segmentID, count := range assignedSegmentInfos {
subRowOffsets := rowOffsets[startPos : startPos+int(count)]
msgs, err := genInsertMsgsByPartition(ctx, segmentID, partitionID, partitionName, subRowOffsets, channelName, insertMsg)
if err != nil {
log.Ctx(ctx).Warn("repack insert data to insert msgs failed",
zap.String("collectionName", insertMsg.CollectionName),
zap.Int64("partitionID", partitionID),
zap.Error(err))
return nil, err
}
res = append(res, msgs...)
startPos += int(count)
}
return res, nil
}
func setMsgID(ctx context.Context,
msgs []msgstream.TsMsg,
idAllocator *allocator.IDAllocator,
) error {
var idBegin int64
var err error
err = retry.Do(ctx, func() error {
idBegin, _, err = idAllocator.Alloc(uint32(len(msgs)))
return err
})
if err != nil {
log.Ctx(ctx).Error("failed to allocate msg id", zap.Error(err))
return err
}
for i, msg := range msgs {
msg.SetID(idBegin + UniqueID(i))
}
return nil
}
func repackInsertData(ctx context.Context,
channelNames []string,
insertMsg *msgstream.InsertMsg,
result *milvuspb.MutationResult,
idAllocator *allocator.IDAllocator,
segIDAssigner *segIDAssigner,
) (*msgstream.MsgPack, error) {
msgPack := &msgstream.MsgPack{
BeginTs: insertMsg.BeginTs(),
EndTs: insertMsg.EndTs(),
}
channel2RowOffsets := assignChannelsByPK(result.IDs, channelNames, insertMsg)
for channel, rowOffsets := range channel2RowOffsets {
partitionName := insertMsg.PartitionName
msgs, err := repackInsertDataByPartition(ctx, partitionName, rowOffsets, channel, insertMsg, segIDAssigner)
if err != nil {
log.Ctx(ctx).Warn("repack insert data to msg pack failed",
zap.String("collectionName", insertMsg.CollectionName),
zap.String("partition name", partitionName),
zap.Error(err))
return nil, err
}
msgPack.Msgs = append(msgPack.Msgs, msgs...)
}
err := setMsgID(ctx, msgPack.Msgs, idAllocator)
if err != nil {
log.Ctx(ctx).Error("failed to set msgID when repack insert data",
zap.String("collectionName", insertMsg.CollectionName),
zap.String("partition name", insertMsg.PartitionName),
zap.Error(err))
return nil, err
}
return msgPack, nil
}
func repackInsertDataWithPartitionKey(ctx context.Context,
channelNames []string,
partitionKeys *schemapb.FieldData,
insertMsg *msgstream.InsertMsg,
result *milvuspb.MutationResult,
idAllocator *allocator.IDAllocator,
segIDAssigner *segIDAssigner,
) (*msgstream.MsgPack, error) {
msgPack := &msgstream.MsgPack{
BeginTs: insertMsg.BeginTs(),
EndTs: insertMsg.EndTs(),
}
channel2RowOffsets := assignChannelsByPK(result.IDs, channelNames, insertMsg)
partitionNames, err := getDefaultPartitionsInPartitionKeyMode(ctx, insertMsg.GetDbName(), insertMsg.CollectionName)
if err != nil {
log.Ctx(ctx).Warn("get default partition names failed in partition key mode",
zap.String("collectionName", insertMsg.CollectionName),
zap.Error(err))
return nil, err
}
hashValues, err := typeutil.HashKey2Partitions(partitionKeys, partitionNames)
if err != nil {
log.Ctx(ctx).Warn("has partition keys to partitions failed",
zap.String("collectionName", insertMsg.CollectionName),
zap.Error(err))
return nil, err
}
for channel, rowOffsets := range channel2RowOffsets {
partition2RowOffsets := make(map[string][]int)
for _, idx := range rowOffsets {
partitionName := partitionNames[hashValues[idx]]
if _, ok := partition2RowOffsets[partitionName]; !ok {
partition2RowOffsets[partitionName] = []int{}
}
partition2RowOffsets[partitionName] = append(partition2RowOffsets[partitionName], idx)
}
errGroup, _ := errgroup.WithContext(ctx)
partition2Msgs := typeutil.NewConcurrentMap[string, []msgstream.TsMsg]()
for partitionName, offsets := range partition2RowOffsets {
partitionName := partitionName
offsets := offsets
errGroup.Go(func() error {
msgs, err := repackInsertDataByPartition(ctx, partitionName, offsets, channel, insertMsg, segIDAssigner)
if err != nil {
return err
}
partition2Msgs.Insert(partitionName, msgs)
return nil
})
}
err = errGroup.Wait()
if err != nil {
log.Ctx(ctx).Warn("repack insert data into insert msg pack failed",
zap.String("collectionName", insertMsg.CollectionName),
zap.String("channelName", channel),
zap.Error(err))
return nil, err
}
partition2Msgs.Range(func(name string, msgs []msgstream.TsMsg) bool {
msgPack.Msgs = append(msgPack.Msgs, msgs...)
return true
})
}
err = setMsgID(ctx, msgPack.Msgs, idAllocator)
if err != nil {
log.Ctx(ctx).Error("failed to set msgID when repack insert data",
zap.String("collectionName", insertMsg.CollectionName),
zap.Error(err))
return nil, err
}
return msgPack, nil
}

View File

@ -21,7 +21,6 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -49,12 +48,6 @@ func TestRepackInsertData(t *testing.T) {
defer mix.Close()
cache := NewMockCache(t)
cache.On("GetPartitionID",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(int64(1), nil)
globalMetaCache = cache
idAllocator, err := allocator.NewIDAllocator(ctx, mix, paramtable.GetNodeID())
@ -113,33 +106,6 @@ func TestRepackInsertData(t *testing.T) {
for index := range insertMsg.RowIDs {
insertMsg.RowIDs[index] = int64(index)
}
ids, err := parsePrimaryFieldData2IDs(fieldData)
assert.NoError(t, err)
result := &milvuspb.MutationResult{
IDs: ids,
}
t.Run("assign segmentID failed", func(t *testing.T) {
fakeSegAllocator, err := newSegIDAssigner(ctx, &mockDataCoord2{expireTime: Timestamp(2500)}, getLastTick1)
assert.NoError(t, err)
_ = fakeSegAllocator.Start()
defer fakeSegAllocator.Close()
_, err = repackInsertData(ctx, []string{"test_dml_channel"}, insertMsg,
result, idAllocator, fakeSegAllocator)
assert.Error(t, err)
})
segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
assert.NoError(t, err)
_ = segAllocator.Start()
defer segAllocator.Close()
t.Run("repack insert data success", func(t *testing.T) {
_, err = repackInsertData(ctx, []string{"test_dml_channel"}, insertMsg, result, idAllocator, segAllocator)
assert.NoError(t, err)
})
}
func TestRepackInsertDataWithPartitionKey(t *testing.T) {
@ -161,11 +127,6 @@ func TestRepackInsertDataWithPartitionKey(t *testing.T) {
_ = idAllocator.Start()
defer idAllocator.Close()
segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
assert.NoError(t, err)
_ = segAllocator.Start()
defer segAllocator.Close()
fieldName2Types := map[string]schemapb.DataType{
testInt64Field: schemapb.DataType_Int64,
testVarCharField: schemapb.DataType_VarChar,
@ -221,17 +182,4 @@ func TestRepackInsertDataWithPartitionKey(t *testing.T) {
for index := range insertMsg.RowIDs {
insertMsg.RowIDs[index] = int64(index)
}
ids, err := parsePrimaryFieldData2IDs(fieldNameToDatas[testInt64Field])
assert.NoError(t, err)
result := &milvuspb.MutationResult{
IDs: ids,
}
t.Run("repack insert data success", func(t *testing.T) {
partitionKeys := generateFieldData(schemapb.DataType_VarChar, testVarCharField, nb)
_, err = repackInsertDataWithPartitionKey(ctx, []string{"test_dml_channel"}, partitionKeys,
insertMsg, result, idAllocator, segAllocator)
assert.NoError(t, err)
})
}

View File

@ -21,14 +21,12 @@ import (
"fmt" "fmt"
"math/rand" "math/rand"
"os" "os"
"strconv"
"sync" "sync"
"time" "time"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/hashicorp/golang-lru/v2/expirable" "github.com/hashicorp/golang-lru/v2/expirable"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/atomic" "go.uber.org/atomic"
"go.uber.org/zap" "go.uber.org/zap"
@ -40,19 +38,15 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/hookutil" "github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics" "github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb" "github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/expr" "github.com/milvus-io/milvus/pkg/v2/util/expr"
"github.com/milvus-io/milvus/pkg/v2/util/logutil" "github.com/milvus-io/milvus/pkg/v2/util/logutil"
"github.com/milvus-io/milvus/pkg/v2/util/metricsinfo" "github.com/milvus-io/milvus/pkg/v2/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/ratelimitutil" "github.com/milvus-io/milvus/pkg/v2/util/ratelimitutil"
"github.com/milvus-io/milvus/pkg/v2/util/resource" "github.com/milvus-io/milvus/pkg/v2/util/resource"
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
) )
@ -87,7 +81,6 @@ type Proxy struct {
stateCode atomic.Int32
etcdCli *clientv3.Client
address string
mixCoord types.MixCoordClient
@ -95,23 +88,16 @@ type Proxy struct {
chMgr channelsMgr
replicateMsgStream msgstream.MsgStream
sched *taskScheduler
chTicker channelsTimeTicker
rowIDAllocator *allocator.IDAllocator
tsoAllocator *timestampAllocator
segAssigner *segIDAssigner
metricsCacheManager *metricsinfo.MetricsCacheManager
session *sessionutil.Session
shardMgr shardClientMgr
factory dependency.Factory
searchResultCh chan *internalpb.SearchResults
// Add callback functions at different stages
@ -122,8 +108,7 @@ type Proxy struct {
lbPolicy LBPolicy
// resource manager
resourceManager resource.Manager
replicateStreamManager *ReplicateStreamManager
// materialized view
enableMaterializedView bool
@ -135,7 +120,7 @@ type Proxy struct {
}
// NewProxy returns a Proxy struct.
func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
func NewProxy(ctx context.Context, _ dependency.Factory) (*Proxy, error) {
rand.Seed(time.Now().UnixNano())
ctx1, cancel := context.WithCancel(ctx)
n := 1024 // better to be configurable
@ -143,18 +128,15 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
lbPolicy := NewLBPolicyImpl(mgr)
lbPolicy.Start(ctx)
resourceManager := resource.NewManager(10*time.Second, 20*time.Second, make(map[string]time.Duration))
replicateStreamManager := NewReplicateStreamManager(ctx, factory, resourceManager)
node := &Proxy{
ctx: ctx1,
cancel: cancel,
factory: factory,
searchResultCh: make(chan *internalpb.SearchResults, n),
shardMgr: mgr,
simpleLimiter: NewSimpleLimiter(Params.QuotaConfig.AllocWaitInterval.GetAsDuration(time.Millisecond), Params.QuotaConfig.AllocRetryTimes.GetAsUint()),
lbPolicy: lbPolicy,
resourceManager: resourceManager,
replicateStreamManager: replicateStreamManager,
slowQueries: expirable.NewLRU[Timestamp, *metricsinfo.SlowQuery](20, nil, time.Minute*15),
}
node.UpdateStateCode(commonpb.StateCode_Abnormal)
expr.Register("proxy", node)
@ -223,10 +205,6 @@ func (node *Proxy) Init() error {
}
log.Info("init session for Proxy done")
node.factory.Init(Params)
log.Debug("init access log for Proxy done")
err := node.initRateCollector()
if err != nil {
return err
@ -253,44 +231,18 @@ func (node *Proxy) Init() error {
node.tsoAllocator = tsoAllocator
log.Debug("create timestamp allocator done", zap.String("role", typeutil.ProxyRole), zap.Int64("ProxyID", paramtable.GetNodeID()))
segAssigner, err := newSegIDAssigner(node.ctx, node.mixCoord, node.lastTick)
if err != nil {
log.Warn("failed to create segment id assigner",
zap.String("role", typeutil.ProxyRole), zap.Int64("ProxyID", paramtable.GetNodeID()),
zap.Error(err))
return err
}
node.segAssigner = segAssigner
node.segAssigner.PeerID = paramtable.GetNodeID()
log.Debug("create segment id assigner done", zap.String("role", typeutil.ProxyRole), zap.Int64("ProxyID", paramtable.GetNodeID()))
dmlChannelsFunc := getDmlChannelsFunc(node.ctx, node.mixCoord)
chMgr := newChannelsMgrImpl(dmlChannelsFunc, defaultInsertRepackFunc, node.factory)
chMgr := newChannelsMgrImpl(dmlChannelsFunc, defaultInsertRepackFunc)
node.chMgr = chMgr
log.Debug("create channels manager done", zap.String("role", typeutil.ProxyRole))
replicateMsgChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
if err != nil {
log.Warn("failed to create replicate msg stream",
zap.String("role", typeutil.ProxyRole), zap.Int64("ProxyID", paramtable.GetNodeID()),
zap.Error(err))
return err
}
node.replicateMsgStream.ForceEnableProduce(true)
node.replicateMsgStream.AsProducer(node.ctx, []string{replicateMsgChannel})
node.sched, err = newTaskScheduler(node.ctx, node.tsoAllocator, node.factory)
node.sched, err = newTaskScheduler(node.ctx, node.tsoAllocator)
if err != nil {
log.Warn("failed to create task scheduler", zap.String("role", typeutil.ProxyRole), zap.Error(err))
return err
}
log.Debug("create task scheduler done", zap.String("role", typeutil.ProxyRole))
syncTimeTickInterval := Params.ProxyCfg.TimeTickInterval.GetAsDuration(time.Millisecond) / 2
node.chTicker = newChannelsTimeTicker(node.ctx, Params.ProxyCfg.TimeTickInterval.GetAsDuration(time.Millisecond)/2, []string{}, node.sched.getPChanStatistics, tsoAllocator)
log.Debug("create channels time ticker done", zap.String("role", typeutil.ProxyRole), zap.Duration("syncTimeTickInterval", syncTimeTickInterval))
node.enableComplexDeleteLimit = Params.QuotaConfig.ComplexDeleteLimitEnable.GetAsBool()
node.metricsCacheManager = metricsinfo.NewMetricsCacheManager()
log.Debug("create metrics cache manager done", zap.String("role", typeutil.ProxyRole))
@ -314,90 +266,6 @@ func (node *Proxy) Init() error {
return nil
}
// sendChannelsTimeTickLoop starts a goroutine that synchronizes the time tick information.
func (node *Proxy) sendChannelsTimeTickLoop() {
log := log.Ctx(node.ctx)
node.wg.Add(1)
go func() {
defer node.wg.Done()
ticker := time.NewTicker(Params.ProxyCfg.TimeTickInterval.GetAsDuration(time.Millisecond))
defer ticker.Stop()
for {
select {
case <-node.ctx.Done():
log.Info("send channels time tick loop exit")
return
case <-ticker.C:
if !Params.CommonCfg.TTMsgEnabled.GetAsBool() {
continue
}
stats, ts, err := node.chTicker.getMinTsStatistics()
if err != nil {
log.Warn("sendChannelsTimeTickLoop.getMinTsStatistics", zap.Error(err))
continue
}
if ts == 0 {
log.Warn("sendChannelsTimeTickLoop.getMinTsStatistics default timestamp equal 0")
continue
}
channels := make([]pChan, 0, len(stats))
tss := make([]Timestamp, 0, len(stats))
maxTs := ts
for channel, ts := range stats {
channels = append(channels, channel)
tss = append(tss, ts)
if ts > maxTs {
maxTs = ts
}
}
req := &internalpb.ChannelTimeTickMsg{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_TimeTick),
commonpbutil.WithSourceID(node.session.ServerID),
),
ChannelNames: channels,
Timestamps: tss,
DefaultTimestamp: maxTs,
}
func() {
// we should pay more attention to the max lag.
minTs := maxTs
minTsOfChannel := "default"
// find the min ts and the related channel.
for channel, ts := range stats {
if ts < minTs {
minTs = ts
minTsOfChannel = channel
}
}
sub := tsoutil.SubByNow(minTs)
metrics.ProxySyncTimeTickLag.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), minTsOfChannel).Set(float64(sub))
}()
status, err := node.mixCoord.UpdateChannelTimeTick(node.ctx, req)
if err != nil {
log.Warn("sendChannelsTimeTickLoop.UpdateChannelTimeTick", zap.Error(err))
continue
}
if status.GetErrorCode() != 0 {
log.Warn("sendChannelsTimeTickLoop.UpdateChannelTimeTick",
zap.Any("ErrorCode", status.ErrorCode),
zap.Any("Reason", status.Reason))
continue
}
}
}
}()
}
// Start starts a proxy node.
func (node *Proxy) Start() error {
log := log.Ctx(node.ctx)
@ -417,22 +285,6 @@ func (node *Proxy) Start() error {
}
log.Debug("start id allocator done", zap.String("role", typeutil.ProxyRole))
if !streamingutil.IsStreamingServiceEnabled() {
if err := node.segAssigner.Start(); err != nil {
log.Warn("failed to start segment id assigner", zap.String("role", typeutil.ProxyRole), zap.Error(err))
return err
}
log.Debug("start segment id assigner done", zap.String("role", typeutil.ProxyRole))
if err := node.chTicker.start(); err != nil {
log.Warn("failed to start channels time ticker", zap.String("role", typeutil.ProxyRole), zap.Error(err))
return err
}
log.Debug("start channels time ticker done", zap.String("role", typeutil.ProxyRole))
node.sendChannelsTimeTickLoop()
}
// Start callbacks
for _, cb := range node.startCallbacks {
cb()
@ -465,21 +317,6 @@ func (node *Proxy) Stop() error {
log.Info("close scheduler", zap.String("role", typeutil.ProxyRole)) log.Info("close scheduler", zap.String("role", typeutil.ProxyRole))
} }
if !streamingutil.IsStreamingServiceEnabled() {
if node.segAssigner != nil {
node.segAssigner.Close()
log.Info("close segment id assigner", zap.String("role", typeutil.ProxyRole))
}
if node.chTicker != nil {
err := node.chTicker.close()
if err != nil {
return err
}
log.Info("close channels time ticker", zap.String("role", typeutil.ProxyRole))
}
}
for _, cb := range node.closeCallbacks {
cb()
}
@ -492,10 +329,6 @@ func (node *Proxy) Stop() error {
node.shardMgr.Close()
}
if node.chMgr != nil {
node.chMgr.removeAllDMLStream()
}
if node.lbPolicy != nil {
node.lbPolicy.Close()
}
@ -519,11 +352,6 @@ func (node *Proxy) AddStartCallback(callbacks ...func()) {
node.startCallbacks = append(node.startCallbacks, callbacks...)
}
// lastTick returns the last write timestamp of all pchans in this Proxy.
func (node *Proxy) lastTick() Timestamp {
return node.chTicker.getMinTick()
}
// AddCloseCallback adds a callback in the Close phase.
func (node *Proxy) AddCloseCallback(callbacks ...func()) {
node.closeCallbacks = append(node.closeCallbacks, callbacks...)
@ -537,11 +365,6 @@ func (node *Proxy) GetAddress() string {
return node.address
}
// SetEtcdClient sets etcd client for proxy.
func (node *Proxy) SetEtcdClient(client *clientv3.Client) {
node.etcdCli = client
}
// SetMixCoordClient sets MixCoord client for proxy.
func (node *Proxy) SetMixCoordClient(cli types.MixCoordClient) {
node.mixCoord = cli

View File

@ -23,13 +23,11 @@ import (
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
"os"
"strconv" "strconv"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/blang/semver/v4"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
@ -50,12 +48,13 @@ import (
mixc "github.com/milvus-io/milvus/internal/distributed/mixcoord/client" mixc "github.com/milvus-io/milvus/internal/distributed/mixcoord/client"
grpcquerynode "github.com/milvus-io/milvus/internal/distributed/querynode" grpcquerynode "github.com/milvus-io/milvus/internal/distributed/querynode"
"github.com/milvus-io/milvus/internal/distributed/streaming" "github.com/milvus-io/milvus/internal/distributed/streaming"
grpcstreamingnode "github.com/milvus-io/milvus/internal/distributed/streamingnode"
"github.com/milvus-io/milvus/internal/json" "github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/mocks/distributed/mock_streaming"
"github.com/milvus-io/milvus/internal/util/componentutil" "github.com/milvus-io/milvus/internal/util/componentutil"
"github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/internal/util/testutil" "github.com/milvus-io/milvus/internal/util/testutil"
"github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/log"
@ -64,7 +63,6 @@ import (
"github.com/milvus-io/milvus/pkg/v2/proto/proxypb" "github.com/milvus-io/milvus/pkg/v2/proto/proxypb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb" "github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb" "github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/tracer" "github.com/milvus-io/milvus/pkg/v2/tracer"
"github.com/milvus-io/milvus/pkg/v2/util" "github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/crypto" "github.com/milvus-io/milvus/pkg/v2/util/crypto"
@ -89,7 +87,6 @@ func init() {
Registry = prometheus.NewRegistry()
Registry.MustRegister(prometheus.NewProcessCollector(prometheus.ProcessCollectorOpts{}))
Registry.MustRegister(prometheus.NewGoCollector())
common.Version = semver.MustParse("2.5.9")
}
func runMixCoord(ctx context.Context, localMsg bool) *grpcmixcoord.Server {
@ -119,6 +116,33 @@ func runMixCoord(ctx context.Context, localMsg bool) *grpcmixcoord.Server {
return rc
}
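// runStreamingNode starts a streaming node gRPC server for the test cluster, mirroring the
// other run* helpers in this file, and registers its metrics with the test Registry.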
func runStreamingNode(ctx context.Context, localMsg bool, alias string) *grpcstreamingnode.Server {
var sn *grpcstreamingnode.Server
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
factory := dependency.MockDefaultFactory(localMsg, Params)
var err error
sn, err = grpcstreamingnode.NewServer(ctx, factory)
if err != nil {
panic(err)
}
if err = sn.Prepare(); err != nil {
panic(err)
}
err = sn.Run()
if err != nil {
panic(err)
}
}()
wg.Wait()
metrics.RegisterStreamingNode(Registry)
return sn
}
func runQueryNode(ctx context.Context, localMsg bool, alias string) *grpcquerynode.Server {
var qn *grpcquerynode.Server
var wg sync.WaitGroup
@ -276,17 +300,9 @@ func TestProxy(t *testing.T) {
paramtable.Init()
params := paramtable.Get()
testutil.ResetEnvironment()
paramtable.SetLocalComponentEnabled(typeutil.StreamingNodeRole)
wal := mock_streaming.NewMockWALAccesser(t)
b := mock_streaming.NewMockBroadcast(t)
streamingutil.SetStreamingServiceEnabled()
defer streamingutil.UnsetStreamingServiceEnabled()
wal.EXPECT().Broadcast().Return(b).Maybe()
b.EXPECT().Append(mock.Anything, mock.Anything).Return(&types.BroadcastAppendResult{}, nil).Maybe()
local := mock_streaming.NewMockLocal(t)
local.EXPECT().GetLatestMVCCTimestampIfLocal(mock.Anything, mock.Anything).Return(0, nil).Maybe()
local.EXPECT().GetMetricsIfLocal(mock.Anything).Return(&types.StreamingNodeMetrics{}, nil).Maybe()
wal.EXPECT().Local().Return(local).Maybe()
streaming.SetWALForTest(wal)
defer streaming.RecoverWALForTest()
params.RootCoordGrpcServerCfg.IP = "localhost"
params.QueryCoordGrpcServerCfg.IP = "localhost"
@ -295,15 +311,14 @@ func TestProxy(t *testing.T) {
params.QueryNodeGrpcServerCfg.IP = "localhost"
params.DataNodeGrpcServerCfg.IP = "localhost"
params.StreamingNodeGrpcServerCfg.IP = "localhost"
params.Save(params.MQCfg.Type.Key, "pulsar")
path := "/tmp/milvus/rocksmq" + funcutil.GenRandomStr() params.CommonCfg.EnableStorageV2.SwapTempValue("false")
t.Setenv("ROCKSMQ_PATH", path) defer params.CommonCfg.EnableStorageV2.SwapTempValue("")
defer os.RemoveAll(path)
ctx, cancel := context.WithCancel(context.Background())
ctx = GetContext(ctx, "root:123456")
localMsg := true
factory := dependency.MockDefaultFactory(localMsg, Params)
factory := dependency.NewDefaultFactory(false)
alias := "TestProxy"
log.Info("Initialize parameter table of Proxy")
@ -314,11 +329,16 @@ func TestProxy(t *testing.T) {
dn := runDataNode(ctx, localMsg, alias)
log.Info("running DataNode ...")
sn := runStreamingNode(ctx, localMsg, alias)
log.Info("running StreamingNode ...")
qn := runQueryNode(ctx, localMsg, alias)
log.Info("running QueryNode ...")
time.Sleep(10 * time.Millisecond)
streaming.Init()
proxy, err := NewProxy(ctx, factory)
assert.NoError(t, err)
assert.NotNil(t, proxy)
@ -331,9 +351,11 @@ func TestProxy(t *testing.T) {
Params.EtcdCfg.EtcdTLSKey.GetValue(),
Params.EtcdCfg.EtcdTLSCACert.GetValue(),
Params.EtcdCfg.EtcdTLSMinVersion.GetValue())
if err != nil {
panic(err)
}
defer etcdcli.Close()
assert.NoError(t, err)
proxy.SetEtcdClient(etcdcli)
testServer := newProxyTestServer(proxy)
wg.Add(1)
@ -372,7 +394,7 @@ func TestProxy(t *testing.T) {
assert.NoError(t, err)
log.Info("Register proxy done")
defer func() {
a := []any{mix, qn, dn, proxy}
a := []any{mix, qn, dn, sn, proxy}
fmt.Println(len(a))
// HINT: the order of stopping service refers to the `roles.go` file
log.Info("start to stop the services")
@ -1106,6 +1128,10 @@ func TestProxy(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
segmentIDs = resp.CollSegIDs[collectionName].Data
// TODO: This is a known gap: a growing segment may not be visible to mixcoord right away;
// it is only visible to the streamingnode immediately, so the flush state should be checked at the streamingnode rather than here.
// In the future, GetFlushState should rely on the timetick instead of a segment list.
time.Sleep(5 * time.Second)
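// Illustrative sketch (an assumption, not part of the change above): the fixed sleep could be
// replaced by a bounded poll over the existing proxy, ctx and segmentIDs variables, for example
// with testify's assert.Eventually, once segment visibility is reliable:
//
//	assert.Eventually(t, func() bool {
//		resp, err := proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{SegmentIDs: segmentIDs})
//		return err == nil && merr.Ok(resp.GetStatus()) && resp.GetFlushed()
//	}, 10*time.Second, 100*time.Millisecond)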
log.Info("flush collection", zap.Int64s("segments to be flushed", segmentIDs)) log.Info("flush collection", zap.Int64s("segments to be flushed", segmentIDs))
f := func() bool { f := func() bool {
@ -4792,11 +4818,7 @@ func TestProxy_Import(t *testing.T) {
cache := globalMetaCache
defer func() { globalMetaCache = cache }()
wal := mock_streaming.NewMockWALAccesser(t)
streaming.SetupNoopWALForTest()
b := mock_streaming.NewMockBroadcast(t)
wal.EXPECT().Broadcast().Return(b).Maybe()
streaming.SetWALForTest(wal)
defer streaming.RecoverWALForTest()
t.Run("Import failed", func(t *testing.T) { t.Run("Import failed", func(t *testing.T) {
proxy := &Proxy{} proxy := &Proxy{}
@ -4825,7 +4847,6 @@ func TestProxy_Import(t *testing.T) {
chMgr.EXPECT().getVChannels(mock.Anything).Return([]string{"foo"}, nil)
proxy.chMgr = chMgr
factory := dependency.NewDefaultFactory(true)
rc := mocks.NewMockRootCoordClient(t)
rc.EXPECT().AllocID(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocIDResponse{
ID: rand.Int63(),
@ -4839,19 +4860,12 @@ func TestProxy_Import(t *testing.T) {
proxy.tsoAllocator = &timestampAllocator{
tso: newMockTimestampAllocatorInterface(),
}
scheduler, err := newTaskScheduler(ctx, proxy.tsoAllocator, factory)
scheduler, err := newTaskScheduler(ctx, proxy.tsoAllocator)
assert.NoError(t, err)
proxy.sched = scheduler
err = proxy.sched.Start()
assert.NoError(t, err)
wal := mock_streaming.NewMockWALAccesser(t)
b := mock_streaming.NewMockBroadcast(t)
wal.EXPECT().Broadcast().Return(b)
b.EXPECT().Append(mock.Anything, mock.Anything).Return(&types.BroadcastAppendResult{}, nil)
streaming.SetWALForTest(wal)
defer streaming.RecoverWALForTest()
req := &milvuspb.ImportRequest{
CollectionName: "dummy",
Files: []string{"a.json"},

View File

@ -1,72 +0,0 @@
package proxy
import (
"context"
"time"
"go.uber.org/zap"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/resource"
)
const (
ReplicateMsgStreamTyp = "replicate_msg_stream"
ReplicateMsgStreamExpireTime = 30 * time.Second
)
type ReplicateStreamManager struct {
ctx context.Context
factory msgstream.Factory
dispatcher msgstream.UnmarshalDispatcher
resourceManager resource.Manager
}
func NewReplicateStreamManager(ctx context.Context, factory msgstream.Factory, resourceManager resource.Manager) *ReplicateStreamManager {
manager := &ReplicateStreamManager{
ctx: ctx,
factory: factory,
dispatcher: (&msgstream.ProtoUDFactory{}).NewUnmarshalDispatcher(),
resourceManager: resourceManager,
}
return manager
}
func (m *ReplicateStreamManager) newMsgStreamResource(ctx context.Context, channel string) resource.NewResourceFunc {
return func() (resource.Resource, error) {
msgStream, err := m.factory.NewMsgStream(ctx)
if err != nil {
log.Ctx(m.ctx).Warn("failed to create msg stream", zap.String("channel", channel), zap.Error(err))
return nil, err
}
msgStream.SetRepackFunc(replicatePackFunc)
msgStream.AsProducer(ctx, []string{channel})
msgStream.ForceEnableProduce(true)
res := resource.NewSimpleResource(msgStream, ReplicateMsgStreamTyp, channel, ReplicateMsgStreamExpireTime, func() {
msgStream.Close()
})
return res, nil
}
}
func (m *ReplicateStreamManager) GetReplicateMsgStream(ctx context.Context, channel string) (msgstream.MsgStream, error) {
ctxLog := log.Ctx(ctx).With(zap.String("proxy_channel", channel))
res, err := m.resourceManager.Get(ReplicateMsgStreamTyp, channel, m.newMsgStreamResource(ctx, channel))
if err != nil {
ctxLog.Warn("failed to get replicate msg stream", zap.String("channel", channel), zap.Error(err))
return nil, err
}
if obj, ok := res.Get().(msgstream.MsgStream); ok && obj != nil {
return obj, nil
}
ctxLog.Warn("invalid resource object", zap.Any("obj", res.Get()))
return nil, merr.ErrInvalidStreamObj
}
func (m *ReplicateStreamManager) GetMsgDispatcher() msgstream.UnmarshalDispatcher {
return m.dispatcher
}

View File

@ -1,85 +0,0 @@
package proxy
import (
"context"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/resource"
)
func TestReplicateManager(t *testing.T) {
factory := newMockMsgStreamFactory()
resourceManager := resource.NewManager(time.Second, 2*time.Second, nil)
manager := NewReplicateStreamManager(context.Background(), factory, resourceManager)
{
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return nil, errors.New("mock msgstream fail")
}
_, err := manager.GetReplicateMsgStream(context.Background(), "test")
assert.Error(t, err)
}
{
mockMsgStream := newMockMsgStream()
i := 0
mockMsgStream.setRepack = func(repackFunc msgstream.RepackFunc) {
i++
}
mockMsgStream.asProducer = func(producers []string) {
i++
}
mockMsgStream.forceEnableProduce = func(b bool) {
i++
}
mockMsgStream.close = func() {
i++
}
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return mockMsgStream, nil
}
_, err := manager.GetReplicateMsgStream(context.Background(), "test")
assert.NoError(t, err)
assert.Equal(t, 3, i)
time.Sleep(time.Second)
_, err = manager.GetReplicateMsgStream(context.Background(), "test")
assert.NoError(t, err)
assert.Equal(t, 3, i)
res := resourceManager.Delete(ReplicateMsgStreamTyp, "test")
assert.NotNil(t, res)
assert.Eventually(t, func() bool {
return resourceManager.Delete(ReplicateMsgStreamTyp, "test") == nil
}, time.Second*4, time.Millisecond*500)
_, err = manager.GetReplicateMsgStream(context.Background(), "test")
assert.NoError(t, err)
assert.Equal(t, 7, i)
}
{
res := resourceManager.Delete(ReplicateMsgStreamTyp, "test")
assert.NotNil(t, res)
assert.Eventually(t, func() bool {
return resourceManager.Delete(ReplicateMsgStreamTyp, "test") == nil
}, time.Second*4, time.Millisecond*500)
res, err := resourceManager.Get(ReplicateMsgStreamTyp, "test", func() (resource.Resource, error) {
return resource.NewResource(resource.WithObj("str")), nil
})
assert.NoError(t, err)
assert.Equal(t, "str", res.Get())
_, err = manager.GetReplicateMsgStream(context.Background(), "test")
assert.ErrorIs(t, err, merr.ErrInvalidStreamObj)
}
{
assert.NotNil(t, manager.GetMsgDispatcher())
}
}

View File

@ -1,401 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"container/list"
"context"
"fmt"
"strconv"
"time"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
const (
segCountPerRPC = 20000
)
// DataCoord is a narrowed interface of DataCoordinator which only provide AssignSegmentID method
type DataCoord interface {
AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error)
}
type segRequest struct {
allocator.BaseRequest
count uint32
collID UniqueID
partitionID UniqueID
segInfo map[UniqueID]uint32
channelName string
timestamp Timestamp
}
type segInfo struct {
segID UniqueID
count uint32
expireTime Timestamp
}
type assignInfo struct {
collID UniqueID
partitionID UniqueID
channelName string
segInfos *list.List
lastInsertTime time.Time
}
func (info *segInfo) IsExpired(ts Timestamp) bool {
return ts > info.expireTime || info.count <= 0
}
func (info *segInfo) Capacity(ts Timestamp) uint32 {
if info.IsExpired(ts) {
return 0
}
return info.count
}
func (info *segInfo) Assign(ts Timestamp, count uint32) uint32 {
if info.IsExpired(ts) {
log.Ctx(context.TODO()).Debug("segInfo Assign IsExpired", zap.Uint64("ts", ts),
zap.Uint32("count", count))
return 0
}
ret := uint32(0)
if info.count >= count {
info.count -= count
ret = count
} else {
ret = info.count
info.count = 0
}
return ret
}
func (info *assignInfo) RemoveExpired(ts Timestamp) {
var next *list.Element
for e := info.segInfos.Front(); e != nil; e = next {
next = e.Next()
segInfo, ok := e.Value.(*segInfo)
if !ok {
log.Warn("can not cast to segInfo")
continue
}
if segInfo.IsExpired(ts) {
info.segInfos.Remove(e)
}
}
}
func (info *assignInfo) Capacity(ts Timestamp) uint32 {
ret := uint32(0)
for e := info.segInfos.Front(); e != nil; e = e.Next() {
segInfo := e.Value.(*segInfo)
ret += segInfo.Capacity(ts)
}
return ret
}
func (info *assignInfo) Assign(ts Timestamp, count uint32) (map[UniqueID]uint32, error) {
capacity := info.Capacity(ts)
if capacity < count {
errMsg := fmt.Sprintf("AssignSegment Failed: capacity:%d is less than count:%d", capacity, count)
return nil, errors.New(errMsg)
}
result := make(map[UniqueID]uint32)
for e := info.segInfos.Front(); e != nil && count != 0; e = e.Next() {
segInfo := e.Value.(*segInfo)
cur := segInfo.Assign(ts, count)
count -= cur
if cur > 0 {
result[segInfo.segID] += cur
}
}
return result, nil
}
type segIDAssigner struct {
allocator.CachedAllocator
assignInfos map[UniqueID]*list.List // collectionID -> *list.List
segReqs []*datapb.SegmentIDRequest
getTickFunc func() Timestamp
PeerID UniqueID
dataCoord DataCoord
countPerRPC uint32
}
// newSegIDAssigner creates a new segIDAssigner
func newSegIDAssigner(ctx context.Context, dataCoord DataCoord, getTickFunc func() Timestamp) (*segIDAssigner, error) {
ctx1, cancel := context.WithCancel(ctx)
sa := &segIDAssigner{
CachedAllocator: allocator.CachedAllocator{
Ctx: ctx1,
CancelFunc: cancel,
Role: "SegmentIDAllocator",
},
countPerRPC: segCountPerRPC,
dataCoord: dataCoord,
assignInfos: make(map[UniqueID]*list.List),
getTickFunc: getTickFunc,
}
sa.TChan = &allocator.Ticker{
UpdateInterval: time.Second,
}
sa.CachedAllocator.SyncFunc = sa.syncSegments
sa.CachedAllocator.ProcessFunc = sa.processFunc
sa.CachedAllocator.CheckSyncFunc = sa.checkSyncFunc
sa.CachedAllocator.PickCanDoFunc = sa.pickCanDoFunc
sa.Init()
return sa, nil
}
func (sa *segIDAssigner) collectExpired() {
ts := sa.getTickFunc()
var next *list.Element
for _, info := range sa.assignInfos {
for e := info.Front(); e != nil; e = next {
next = e.Next()
assign := e.Value.(*assignInfo)
assign.RemoveExpired(ts)
if assign.Capacity(ts) == 0 {
info.Remove(e)
}
}
}
}
func (sa *segIDAssigner) pickCanDoFunc() {
if sa.ToDoReqs == nil {
return
}
records := make(map[UniqueID]map[UniqueID]map[string]uint32)
var newTodoReqs []allocator.Request
for _, req := range sa.ToDoReqs {
segRequest := req.(*segRequest)
collID := segRequest.collID
partitionID := segRequest.partitionID
channelName := segRequest.channelName
if _, ok := records[collID]; !ok {
records[collID] = make(map[UniqueID]map[string]uint32)
}
if _, ok := records[collID][partitionID]; !ok {
records[collID][partitionID] = make(map[string]uint32)
}
if _, ok := records[collID][partitionID][channelName]; !ok {
records[collID][partitionID][channelName] = 0
}
records[collID][partitionID][channelName] += segRequest.count
assign, err := sa.getAssign(segRequest.collID, segRequest.partitionID, segRequest.channelName)
if err != nil || assign.Capacity(segRequest.timestamp) < records[collID][partitionID][channelName] {
sa.segReqs = append(sa.segReqs, &datapb.SegmentIDRequest{
ChannelName: channelName,
Count: segRequest.count,
CollectionID: collID,
PartitionID: partitionID,
})
newTodoReqs = append(newTodoReqs, req)
} else {
sa.CanDoReqs = append(sa.CanDoReqs, req)
}
}
log.Ctx(context.TODO()).Debug("Proxy segIDAssigner pickCanDoFunc", zap.Any("records", records),
zap.Int("len(newTodoReqs)", len(newTodoReqs)),
zap.Int("len(CanDoReqs)", len(sa.CanDoReqs)))
sa.ToDoReqs = newTodoReqs
}
func (sa *segIDAssigner) getAssign(collID UniqueID, partitionID UniqueID, channelName string) (*assignInfo, error) {
assignInfos, ok := sa.assignInfos[collID]
if !ok {
return nil, fmt.Errorf("can not find collection %d", collID)
}
for e := assignInfos.Front(); e != nil; e = e.Next() {
info := e.Value.(*assignInfo)
if info.partitionID != partitionID || info.channelName != channelName {
continue
}
return info, nil
}
return nil, fmt.Errorf("can not find assign info with collID %d, partitionID %d, channelName %s",
collID, partitionID, channelName)
}
func (sa *segIDAssigner) checkSyncFunc(timeout bool) bool {
sa.collectExpired()
return timeout || len(sa.segReqs) != 0
}
func (sa *segIDAssigner) checkSegReqEqual(req1, req2 *datapb.SegmentIDRequest) bool {
if req1 == nil || req2 == nil {
return false
}
if req1 == req2 {
return true
}
return req1.CollectionID == req2.CollectionID && req1.PartitionID == req2.PartitionID && req1.ChannelName == req2.ChannelName
}
func (sa *segIDAssigner) reduceSegReqs() {
log.Ctx(context.TODO()).Debug("Proxy segIDAssigner reduceSegReqs", zap.Int("len(segReqs)", len(sa.segReqs)))
if len(sa.segReqs) == 0 {
return
}
beforeCnt := uint32(0)
var newSegReqs []*datapb.SegmentIDRequest
for _, req1 := range sa.segReqs {
if req1.Count == 0 {
log.Ctx(context.TODO()).Debug("Proxy segIDAssigner reduceSegReqs hit perCount == 0")
req1.Count = sa.countPerRPC
}
beforeCnt += req1.Count
var req2 *datapb.SegmentIDRequest
for _, req3 := range newSegReqs {
if sa.checkSegReqEqual(req1, req3) {
req2 = req3
break
}
}
if req2 == nil { // not found
newSegReqs = append(newSegReqs, req1)
} else {
req2.Count += req1.Count
}
}
afterCnt := uint32(0)
for _, req := range newSegReqs {
afterCnt += req.Count
}
sa.segReqs = newSegReqs
log.Ctx(context.TODO()).Debug("Proxy segIDAssigner reduceSegReqs after reduce", zap.Int("len(segReqs)", len(sa.segReqs)),
zap.Uint32("BeforeCnt", beforeCnt),
zap.Uint32("AfterCnt", afterCnt))
}
func (sa *segIDAssigner) syncSegments() (bool, error) {
if len(sa.segReqs) == 0 {
return true, nil
}
sa.reduceSegReqs()
req := &datapb.AssignSegmentIDRequest{
NodeID: sa.PeerID,
PeerRole: typeutil.ProxyRole,
SegmentIDRequests: sa.segReqs,
}
metrics.ProxySyncSegmentRequestLength.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(len(sa.segReqs)))
sa.segReqs = nil
log.Ctx(context.TODO()).Debug("syncSegments call dataCoord.AssignSegmentID", zap.Stringer("request", req))
resp, err := sa.dataCoord.AssignSegmentID(context.Background(), req)
if err != nil {
return false, fmt.Errorf("syncSegmentID Failed:%w", err)
}
if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
return false, fmt.Errorf("syncSegmentID Failed:%s", resp.GetStatus().GetReason())
}
var errMsg string
now := time.Now()
success := true
for _, segAssign := range resp.SegIDAssignments {
if segAssign.Status.GetErrorCode() != commonpb.ErrorCode_Success {
log.Ctx(context.TODO()).Warn("proxy", zap.String("SyncSegment Error", segAssign.GetStatus().GetReason()))
errMsg += segAssign.GetStatus().GetReason()
errMsg += "\n"
success = false
continue
}
assign, err := sa.getAssign(segAssign.CollectionID, segAssign.PartitionID, segAssign.ChannelName)
segInfo2 := &segInfo{
segID: segAssign.SegID,
count: segAssign.Count,
expireTime: segAssign.ExpireTime,
}
if err != nil {
colInfos, ok := sa.assignInfos[segAssign.CollectionID]
if !ok {
colInfos = list.New()
}
segInfos := list.New()
segInfos.PushBack(segInfo2)
assign = &assignInfo{
collID: segAssign.CollectionID,
partitionID: segAssign.PartitionID,
channelName: segAssign.ChannelName,
segInfos: segInfos,
}
colInfos.PushBack(assign)
sa.assignInfos[segAssign.CollectionID] = colInfos
} else {
assign.segInfos.PushBack(segInfo2)
}
assign.lastInsertTime = now
}
if !success {
return false, errors.New(errMsg)
}
return success, nil
}
func (sa *segIDAssigner) processFunc(req allocator.Request) error {
segRequest := req.(*segRequest)
assign, err := sa.getAssign(segRequest.collID, segRequest.partitionID, segRequest.channelName)
if err != nil {
return err
}
result, err2 := assign.Assign(segRequest.timestamp, segRequest.count)
segRequest.segInfo = result
return err2
}
func (sa *segIDAssigner) GetSegmentID(collID UniqueID, partitionID UniqueID, channelName string, count uint32, ts Timestamp) (map[UniqueID]uint32, error) {
req := &segRequest{
BaseRequest: allocator.BaseRequest{Done: make(chan error), Valid: false},
collID: collID,
partitionID: partitionID,
channelName: channelName,
count: count,
timestamp: ts,
}
sa.Reqs <- req
if err := req.Wait(); err != nil {
return nil, fmt.Errorf("getSegmentID failed: %s", err)
}
return req.segInfo, nil
}

View File

@ -1,307 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"context"
"math/rand"
"sync"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
type mockDataCoord struct {
expireTime Timestamp
}
func (mockD *mockDataCoord) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) {
assigns := make([]*datapb.SegmentIDAssignment, 0, len(req.SegmentIDRequests))
maxPerCnt := 100
for _, r := range req.SegmentIDRequests {
totalCnt := uint32(0)
for totalCnt != r.Count {
cnt := uint32(rand.Intn(maxPerCnt))
if totalCnt+cnt > r.Count {
cnt = r.Count - totalCnt
}
totalCnt += cnt
result := &datapb.SegmentIDAssignment{
SegID: 1,
ChannelName: r.ChannelName,
Count: cnt,
CollectionID: r.CollectionID,
PartitionID: r.PartitionID,
ExpireTime: mockD.expireTime,
Status: merr.Success(),
}
assigns = append(assigns, result)
}
}
return &datapb.AssignSegmentIDResponse{
Status: merr.Success(),
SegIDAssignments: assigns,
}, nil
}
type mockDataCoord2 struct {
expireTime Timestamp
}
func (mockD *mockDataCoord2) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) {
return &datapb.AssignSegmentIDResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "Just For Test",
},
}, nil
}
func getLastTick1() Timestamp {
return 1000
}
func TestSegmentAllocator1(t *testing.T) {
ctx := context.Background()
dataCoord := &mockDataCoord{}
dataCoord.expireTime = Timestamp(1000)
segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick1)
assert.NoError(t, err)
wg := &sync.WaitGroup{}
segAllocator.Start()
wg.Add(1)
go func(group *sync.WaitGroup) {
defer group.Done()
time.Sleep(100 * time.Millisecond)
segAllocator.Close()
}(wg)
total := uint32(0)
collNames := []string{"abc", "cba"}
for i := 0; i < 10; i++ {
colName := collNames[i%2]
ret, err := segAllocator.GetSegmentID(1, 1, colName, 1, 1)
assert.NoError(t, err)
total += ret[1]
}
assert.Equal(t, uint32(10), total)
ret, err := segAllocator.GetSegmentID(1, 1, "abc", segCountPerRPC-10, 999)
assert.NoError(t, err)
assert.Equal(t, uint32(segCountPerRPC-10), ret[1])
_, err = segAllocator.GetSegmentID(1, 1, "abc", 10, 1001)
assert.Error(t, err)
wg.Wait()
}
var curLastTick2 = Timestamp(200)
var curLastTIck2Lock sync.Mutex
func getLastTick2() Timestamp {
curLastTIck2Lock.Lock()
defer curLastTIck2Lock.Unlock()
curLastTick2 += 100
return curLastTick2
}
func TestSegmentAllocator2(t *testing.T) {
ctx := context.Background()
dataCoord := &mockDataCoord{}
dataCoord.expireTime = Timestamp(500)
segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick2)
assert.NoError(t, err)
segAllocator.Start()
defer segAllocator.Close()
total := uint32(0)
for i := 0; i < 10; i++ {
ret, err := segAllocator.GetSegmentID(1, 1, "abc", 1, 200)
assert.NoError(t, err)
total += ret[1]
}
assert.Equal(t, uint32(10), total)
time.Sleep(50 * time.Millisecond)
_, err = segAllocator.GetSegmentID(1, 1, "abc", segCountPerRPC-10, getLastTick2())
assert.Error(t, err)
}
func TestSegmentAllocator3(t *testing.T) {
ctx := context.Background()
dataCoord := &mockDataCoord2{}
dataCoord.expireTime = Timestamp(500)
segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick2)
assert.NoError(t, err)
wg := &sync.WaitGroup{}
segAllocator.Start()
wg.Add(1)
go func(group *sync.WaitGroup) {
defer group.Done()
time.Sleep(100 * time.Millisecond)
segAllocator.Close()
}(wg)
time.Sleep(50 * time.Millisecond)
_, err = segAllocator.GetSegmentID(1, 1, "abc", 10, 100)
assert.Error(t, err)
wg.Wait()
}
type mockDataCoord3 struct {
expireTime Timestamp
}
func (mockD *mockDataCoord3) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) {
assigns := make([]*datapb.SegmentIDAssignment, 0, len(req.SegmentIDRequests))
for i, r := range req.SegmentIDRequests {
errCode := commonpb.ErrorCode_Success
reason := ""
if i == 0 {
errCode = commonpb.ErrorCode_UnexpectedError
reason = "Just for test"
}
result := &datapb.SegmentIDAssignment{
SegID: 1,
ChannelName: r.ChannelName,
Count: r.Count,
CollectionID: r.CollectionID,
PartitionID: r.PartitionID,
ExpireTime: mockD.expireTime,
Status: &commonpb.Status{
ErrorCode: errCode,
Reason: reason,
},
}
assigns = append(assigns, result)
}
return &datapb.AssignSegmentIDResponse{
Status: merr.Success(),
SegIDAssignments: assigns,
}, nil
}
func TestSegmentAllocator4(t *testing.T) {
ctx := context.Background()
dataCoord := &mockDataCoord3{}
dataCoord.expireTime = Timestamp(500)
segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick2)
assert.NoError(t, err)
wg := &sync.WaitGroup{}
segAllocator.Start()
wg.Add(1)
go func(group *sync.WaitGroup) {
defer group.Done()
time.Sleep(100 * time.Millisecond)
segAllocator.Close()
}(wg)
time.Sleep(50 * time.Millisecond)
_, err = segAllocator.GetSegmentID(1, 1, "abc", 10, 100)
assert.Error(t, err)
wg.Wait()
}
type mockDataCoord5 struct {
expireTime Timestamp
}
func (mockD *mockDataCoord5) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) {
return &datapb.AssignSegmentIDResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "Just For Test",
},
}, errors.New("just for test")
}
func TestSegmentAllocator5(t *testing.T) {
ctx := context.Background()
dataCoord := &mockDataCoord5{}
dataCoord.expireTime = Timestamp(500)
segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick2)
assert.NoError(t, err)
wg := &sync.WaitGroup{}
segAllocator.Start()
wg.Add(1)
go func(group *sync.WaitGroup) {
defer group.Done()
time.Sleep(100 * time.Millisecond)
segAllocator.Close()
}(wg)
time.Sleep(50 * time.Millisecond)
_, err = segAllocator.GetSegmentID(1, 1, "abc", 10, 100)
assert.Error(t, err)
wg.Wait()
}
func TestSegmentAllocator6(t *testing.T) {
ctx := context.Background()
dataCoord := &mockDataCoord{}
dataCoord.expireTime = Timestamp(500)
segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick2)
assert.NoError(t, err)
wg := &sync.WaitGroup{}
segAllocator.Start()
wg.Add(1)
go func(group *sync.WaitGroup) {
defer group.Done()
time.Sleep(100 * time.Millisecond)
segAllocator.Close()
}(wg)
success := true
var sucLock sync.Mutex
collNames := []string{"abc", "cba"}
reqFunc := func(i int, group *sync.WaitGroup) {
defer group.Done()
sucLock.Lock()
defer sucLock.Unlock()
if !success {
return
}
colName := collNames[i%2]
count := uint32(10)
if i == 0 {
count = 0
}
_, err = segAllocator.GetSegmentID(1, 1, colName, count, 100)
if err != nil {
t.Log(err)
success = false
}
}
for i := 0; i < 10; i++ {
wg.Add(1)
go reqFunc(i, wg)
}
wg.Wait()
assert.True(t, success)
}

View File

@ -587,7 +587,6 @@ type dropCollectionTask struct {
mixCoord types.MixCoordClient mixCoord types.MixCoordClient
result *commonpb.Status result *commonpb.Status
chMgr channelsMgr chMgr channelsMgr
chTicker channelsTimeTicker
} }
func (t *dropCollectionTask) TraceCtx() context.Context { func (t *dropCollectionTask) TraceCtx() context.Context {
@ -1047,10 +1046,9 @@ type alterCollectionTask struct {
baseTask baseTask
Condition Condition
*milvuspb.AlterCollectionRequest *milvuspb.AlterCollectionRequest
ctx context.Context ctx context.Context
mixCoord types.MixCoordClient mixCoord types.MixCoordClient
result *commonpb.Status result *commonpb.Status
replicateMsgStream msgstream.MsgStream
} }
func (t *alterCollectionTask) TraceCtx() context.Context { func (t *alterCollectionTask) TraceCtx() context.Context {
@ -1266,7 +1264,6 @@ func (t *alterCollectionTask) Execute(ctx context.Context) error {
if err = merr.CheckRPCCall(t.result, err); err != nil { if err = merr.CheckRPCCall(t.result, err); err != nil {
return err return err
} }
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.AlterCollectionRequest)
return nil return nil
} }
@ -1278,10 +1275,9 @@ type alterCollectionFieldTask struct {
baseTask baseTask
Condition Condition
*milvuspb.AlterCollectionFieldRequest *milvuspb.AlterCollectionFieldRequest
ctx context.Context ctx context.Context
mixCoord types.MixCoordClient mixCoord types.MixCoordClient
result *commonpb.Status result *commonpb.Status
replicateMsgStream msgstream.MsgStream
} }
func (t *alterCollectionFieldTask) TraceCtx() context.Context { func (t *alterCollectionFieldTask) TraceCtx() context.Context {
@ -1495,7 +1491,6 @@ func (t *alterCollectionFieldTask) Execute(ctx context.Context) error {
if err = merr.CheckRPCCall(t.result, err); err != nil { if err = merr.CheckRPCCall(t.result, err); err != nil {
return err return err
} }
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.AlterCollectionFieldRequest)
return nil return nil
} }
@ -1938,8 +1933,7 @@ type loadCollectionTask struct {
mixCoord types.MixCoordClient mixCoord types.MixCoordClient
result *commonpb.Status result *commonpb.Status
collectionID UniqueID collectionID UniqueID
replicateMsgStream msgstream.MsgStream
} }
func (t *loadCollectionTask) TraceCtx() context.Context { func (t *loadCollectionTask) TraceCtx() context.Context {
@ -2088,7 +2082,6 @@ func (t *loadCollectionTask) Execute(ctx context.Context) (err error) {
if err = merr.CheckRPCCall(t.result, err); err != nil { if err = merr.CheckRPCCall(t.result, err); err != nil {
return fmt.Errorf("call query coordinator LoadCollection: %s", err) return fmt.Errorf("call query coordinator LoadCollection: %s", err)
} }
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.LoadCollectionRequest)
return nil return nil
} }
@ -2111,8 +2104,7 @@ type releaseCollectionTask struct {
mixCoord types.MixCoordClient mixCoord types.MixCoordClient
result *commonpb.Status result *commonpb.Status
collectionID UniqueID collectionID UniqueID
replicateMsgStream msgstream.MsgStream
} }
func (t *releaseCollectionTask) TraceCtx() context.Context { func (t *releaseCollectionTask) TraceCtx() context.Context {
@ -2186,7 +2178,6 @@ func (t *releaseCollectionTask) Execute(ctx context.Context) (err error) {
return err return err
} }
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.ReleaseCollectionRequest)
return nil return nil
} }
@ -2202,8 +2193,7 @@ type loadPartitionsTask struct {
mixCoord types.MixCoordClient mixCoord types.MixCoordClient
result *commonpb.Status result *commonpb.Status
collectionID UniqueID collectionID UniqueID
replicateMsgStream msgstream.MsgStream
} }
func (t *loadPartitionsTask) TraceCtx() context.Context { func (t *loadPartitionsTask) TraceCtx() context.Context {
@ -2362,7 +2352,6 @@ func (t *loadPartitionsTask) Execute(ctx context.Context) error {
if err = merr.CheckRPCCall(t.result, err); err != nil { if err = merr.CheckRPCCall(t.result, err); err != nil {
return err return err
} }
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.LoadPartitionsRequest)
return nil return nil
} }
@ -2379,8 +2368,7 @@ type releasePartitionsTask struct {
mixCoord types.MixCoordClient mixCoord types.MixCoordClient
result *commonpb.Status result *commonpb.Status
collectionID UniqueID collectionID UniqueID
replicateMsgStream msgstream.MsgStream
} }
func (t *releasePartitionsTask) TraceCtx() context.Context { func (t *releasePartitionsTask) TraceCtx() context.Context {
@ -2469,7 +2457,6 @@ func (t *releasePartitionsTask) Execute(ctx context.Context) (err error) {
if err = merr.CheckRPCCall(t.result, err); err != nil { if err = merr.CheckRPCCall(t.result, err); err != nil {
return err return err
} }
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.ReleasePartitionsRequest)
return nil return nil
} }

View File

@ -22,7 +22,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil" "github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/paramtable"
@ -33,10 +32,9 @@ type CreateAliasTask struct {
baseTask baseTask
Condition Condition
*milvuspb.CreateAliasRequest *milvuspb.CreateAliasRequest
ctx context.Context ctx context.Context
mixCoord types.MixCoordClient mixCoord types.MixCoordClient
replicateMsgStream msgstream.MsgStream result *commonpb.Status
result *commonpb.Status
} }
// TraceCtx returns the trace context of the task. // TraceCtx returns the trace context of the task.
@ -111,7 +109,6 @@ func (t *CreateAliasTask) Execute(ctx context.Context) error {
if err = merr.CheckRPCCall(t.result, err); err != nil { if err = merr.CheckRPCCall(t.result, err); err != nil {
return err return err
} }
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.CreateAliasRequest)
return nil return nil
} }
@ -125,10 +122,9 @@ type DropAliasTask struct {
baseTask baseTask
Condition Condition
*milvuspb.DropAliasRequest *milvuspb.DropAliasRequest
ctx context.Context ctx context.Context
mixCoord types.MixCoordClient mixCoord types.MixCoordClient
replicateMsgStream msgstream.MsgStream result *commonpb.Status
result *commonpb.Status
} }
// TraceCtx returns the context for trace // TraceCtx returns the context for trace
@ -189,7 +185,6 @@ func (t *DropAliasTask) Execute(ctx context.Context) error {
if err = merr.CheckRPCCall(t.result, err); err != nil { if err = merr.CheckRPCCall(t.result, err); err != nil {
return err return err
} }
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.DropAliasRequest)
return nil return nil
} }
@ -202,10 +197,9 @@ type AlterAliasTask struct {
baseTask baseTask
Condition Condition
*milvuspb.AlterAliasRequest *milvuspb.AlterAliasRequest
ctx context.Context ctx context.Context
mixCoord types.MixCoordClient mixCoord types.MixCoordClient
replicateMsgStream msgstream.MsgStream result *commonpb.Status
result *commonpb.Status
} }
func (t *AlterAliasTask) TraceCtx() context.Context { func (t *AlterAliasTask) TraceCtx() context.Context {
@ -270,7 +264,6 @@ func (t *AlterAliasTask) Execute(ctx context.Context) error {
if err = merr.CheckRPCCall(t.result, err); err != nil { if err = merr.CheckRPCCall(t.result, err); err != nil {
return err return err
} }
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.AlterAliasRequest)
return nil return nil
} }

View File

@ -11,7 +11,6 @@ import (
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb" "github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil" "github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/merr"
@ -25,8 +24,6 @@ type createDatabaseTask struct {
ctx context.Context ctx context.Context
mixCoord types.MixCoordClient mixCoord types.MixCoordClient
result *commonpb.Status result *commonpb.Status
replicateMsgStream msgstream.MsgStream
} }
func (cdt *createDatabaseTask) TraceCtx() context.Context { func (cdt *createDatabaseTask) TraceCtx() context.Context {
@ -78,9 +75,6 @@ func (cdt *createDatabaseTask) Execute(ctx context.Context) error {
var err error var err error
cdt.result, err = cdt.mixCoord.CreateDatabase(ctx, cdt.CreateDatabaseRequest) cdt.result, err = cdt.mixCoord.CreateDatabase(ctx, cdt.CreateDatabaseRequest)
err = merr.CheckRPCCall(cdt.result, err) err = merr.CheckRPCCall(cdt.result, err)
if err == nil {
SendReplicateMessagePack(ctx, cdt.replicateMsgStream, cdt.CreateDatabaseRequest)
}
return err return err
} }
@ -95,8 +89,6 @@ type dropDatabaseTask struct {
ctx context.Context ctx context.Context
mixCoord types.MixCoordClient mixCoord types.MixCoordClient
result *commonpb.Status result *commonpb.Status
replicateMsgStream msgstream.MsgStream
} }
func (ddt *dropDatabaseTask) TraceCtx() context.Context { func (ddt *dropDatabaseTask) TraceCtx() context.Context {
@ -151,7 +143,6 @@ func (ddt *dropDatabaseTask) Execute(ctx context.Context) error {
err = merr.CheckRPCCall(ddt.result, err) err = merr.CheckRPCCall(ddt.result, err)
if err == nil { if err == nil {
globalMetaCache.RemoveDatabase(ctx, ddt.DbName) globalMetaCache.RemoveDatabase(ctx, ddt.DbName)
SendReplicateMessagePack(ctx, ddt.replicateMsgStream, ddt.DropDatabaseRequest)
} }
return err return err
} }
@ -230,8 +221,6 @@ type alterDatabaseTask struct {
ctx context.Context ctx context.Context
mixCoord types.MixCoordClient mixCoord types.MixCoordClient
result *commonpb.Status result *commonpb.Status
replicateMsgStream msgstream.MsgStream
} }
func (t *alterDatabaseTask) TraceCtx() context.Context { func (t *alterDatabaseTask) TraceCtx() context.Context {
@ -323,8 +312,6 @@ func (t *alterDatabaseTask) Execute(ctx context.Context) error {
if err != nil { if err != nil {
return err return err
} }
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.AlterDatabaseRequest)
t.result = ret t.result = ret
return nil return nil
} }

View File

@ -2,13 +2,11 @@ package proxy
import ( import (
"context" "context"
"fmt"
"io" "io"
"strconv" "strconv"
"time" "time"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"go.opentelemetry.io/otel"
"go.uber.org/atomic" "go.uber.org/atomic"
"go.uber.org/zap" "go.uber.org/zap"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
@ -21,7 +19,6 @@ import (
"github.com/milvus-io/milvus/internal/parser/planparserv2" "github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/exprutil" "github.com/milvus-io/milvus/internal/util/exprutil"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics" "github.com/milvus-io/milvus/pkg/v2/metrics"
@ -133,60 +130,6 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error {
return nil return nil
} }
func (dt *deleteTask) Execute(ctx context.Context) (err error) {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Delete-Execute")
defer sp.End()
if len(dt.req.GetExpr()) == 0 {
return merr.WrapErrParameterInvalid("valid expr", "empty expr", "invalid expression")
}
dt.tr = timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute delete %d", dt.ID()))
stream, err := dt.chMgr.getOrCreateDmlStream(ctx, dt.collectionID)
if err != nil {
return err
}
result, numRows, err := repackDeleteMsgByHash(
ctx,
dt.primaryKeys, dt.vChannels,
dt.idAllocator, dt.ts,
dt.collectionID, dt.req.GetCollectionName(),
dt.partitionID, dt.req.GetPartitionName(),
dt.req.GetDbName(),
)
if err != nil {
return err
}
// send delete request to log broker
msgPack := &msgstream.MsgPack{
BeginTs: dt.BeginTs(),
EndTs: dt.EndTs(),
}
for _, msgs := range result {
for _, msg := range msgs {
msgPack.Msgs = append(msgPack.Msgs, msg)
}
}
log.Ctx(ctx).Debug("send delete request to virtual channels",
zap.String("collectionName", dt.req.GetCollectionName()),
zap.Int64("collectionID", dt.collectionID),
zap.Strings("virtual_channels", dt.vChannels),
zap.Int64("taskID", dt.ID()),
zap.Duration("prepare duration", dt.tr.RecordSpan()))
err = stream.Produce(ctx, msgPack)
if err != nil {
return err
}
dt.sessionTS = dt.ts
dt.count += numRows
return nil
}
func (dt *deleteTask) PostExecute(ctx context.Context) error { func (dt *deleteTask) PostExecute(ctx context.Context) error {
metrics.ProxyDeleteVectors.WithLabelValues( metrics.ProxyDeleteVectors.WithLabelValues(
paramtable.GetStringNodeID(), paramtable.GetStringNodeID(),
@ -288,7 +231,6 @@ type deleteRunner struct {
// channel // channel
chMgr channelsMgr chMgr channelsMgr
chTicker channelsTimeTicker
vChannels []vChan vChannels []vChan
idAllocator allocator.Interface idAllocator allocator.Interface
@ -437,20 +379,13 @@ func (dr *deleteRunner) produce(ctx context.Context, primaryKeys *schemapb.IDs,
req: dr.req, req: dr.req,
idAllocator: dr.idAllocator, idAllocator: dr.idAllocator,
chMgr: dr.chMgr, chMgr: dr.chMgr,
chTicker: dr.chTicker,
collectionID: dr.collectionID, collectionID: dr.collectionID,
partitionID: partitionID, partitionID: partitionID,
vChannels: dr.vChannels, vChannels: dr.vChannels,
primaryKeys: primaryKeys, primaryKeys: primaryKeys,
dbID: dr.dbID, dbID: dr.dbID,
} }
if err := dr.queue.Enqueue(dt); err != nil {
var enqueuedTask task = dt
if streamingutil.IsStreamingServiceEnabled() {
enqueuedTask = &deleteTaskByStreamingService{deleteTask: dt}
}
if err := dr.queue.Enqueue(enqueuedTask); err != nil {
log.Ctx(ctx).Error("Failed to enqueue delete task: " + err.Error()) log.Ctx(ctx).Error("Failed to enqueue delete task: " + err.Error())
return nil, err return nil, err
} }

View File

@ -15,13 +15,9 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/typeutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
) )
type deleteTaskByStreamingService struct {
*deleteTask
}
// Execute is a function to delete task by streaming service // Execute is a function to delete task by streaming service
// we only overwrite the Execute function // we only overwrite the Execute function
func (dt *deleteTaskByStreamingService) Execute(ctx context.Context) (err error) { func (dt *deleteTask) Execute(ctx context.Context) (err error) {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Delete-Execute") ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Delete-Execute")
defer sp.End() defer sp.End()

View File

@ -14,11 +14,11 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/parser/planparserv2" "github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/util/streamrpc" "github.com/milvus-io/milvus/internal/util/streamrpc"
"github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb" "github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/planpb" "github.com/milvus-io/milvus/pkg/v2/proto/planpb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb" "github.com/milvus-io/milvus/pkg/v2/proto/querypb"
@ -152,19 +152,6 @@ func TestDeleteTask_Execute(t *testing.T) {
assert.Error(t, dt.Execute(context.Background())) assert.Error(t, dt.Execute(context.Background()))
}) })
t.Run("get channel failed", func(t *testing.T) {
mockMgr := NewMockChannelsMgr(t)
dt := deleteTask{
chMgr: mockMgr,
req: &milvuspb.DeleteRequest{
Expr: "pk in [1,2]",
},
}
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
assert.Error(t, dt.Execute(context.Background()))
})
t.Run("alloc failed", func(t *testing.T) { t.Run("alloc failed", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -189,9 +176,6 @@ func TestDeleteTask_Execute(t *testing.T) {
}, },
primaryKeys: pk, primaryKeys: pk,
} }
stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
assert.Error(t, dt.Execute(context.Background())) assert.Error(t, dt.Execute(context.Background()))
}) })
@ -225,9 +209,7 @@ func TestDeleteTask_Execute(t *testing.T) {
}, },
primaryKeys: pk, primaryKeys: pk,
} }
stream := msgstream.NewMockMsgStream(t) streaming.ExpectErrorOnce(errors.New("mock error"))
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("mock error"))
assert.Error(t, dt.Execute(context.Background())) assert.Error(t, dt.Execute(context.Background()))
}) })
} }
@ -666,7 +648,7 @@ func TestDeleteRunner_Run(t *testing.T) {
tsoAllocator := &mockTsoAllocator{} tsoAllocator := &mockTsoAllocator{}
idAllocator := &mockIDAllocatorInterface{} idAllocator := &mockIDAllocatorInterface{}
queue, err := newTaskScheduler(ctx, tsoAllocator, nil) queue, err := newTaskScheduler(ctx, tsoAllocator)
assert.NoError(t, err) assert.NoError(t, err)
queue.Start() queue.Start()
defer queue.Close() defer queue.Close()
@ -731,11 +713,8 @@ func TestDeleteRunner_Run(t *testing.T) {
}, },
plan: plan, plan: plan,
} }
stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil) mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("mock error")) streaming.ExpectErrorOnce(errors.New("mock error"))
assert.Error(t, dr.Run(context.Background())) assert.Error(t, dr.Run(context.Background()))
assert.Equal(t, int64(0), dr.result.DeleteCnt) assert.Equal(t, int64(0), dr.result.DeleteCnt)
}) })
@ -814,10 +793,7 @@ func TestDeleteRunner_Run(t *testing.T) {
}, },
plan: plan, plan: plan,
} }
stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil) mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error { lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "") return workload.exec(ctx, 1, qn, "")
@ -946,8 +922,6 @@ func TestDeleteRunner_Run(t *testing.T) {
}, },
plan: plan, plan: plan,
} }
stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil) mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error { lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "") return workload.exec(ctx, 1, qn, "")
@ -971,8 +945,8 @@ func TestDeleteRunner_Run(t *testing.T) {
server.FinishSend(nil) server.FinishSend(nil)
return client return client
}, nil) }, nil)
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("mock error"))
streaming.ExpectErrorOnce(errors.New("mock error"))
assert.Error(t, dr.Run(ctx)) assert.Error(t, dr.Run(ctx))
assert.Equal(t, int64(0), dr.result.DeleteCnt) assert.Equal(t, int64(0), dr.result.DeleteCnt)
}) })
@ -1012,8 +986,6 @@ func TestDeleteRunner_Run(t *testing.T) {
}, },
plan: plan, plan: plan,
} }
stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil) mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error { lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "") return workload.exec(ctx, 1, qn, "")
@ -1037,8 +1009,6 @@ func TestDeleteRunner_Run(t *testing.T) {
server.FinishSend(nil) server.FinishSend(nil)
return client return client
}, nil) }, nil)
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
assert.NoError(t, dr.Run(ctx)) assert.NoError(t, dr.Run(ctx))
assert.Equal(t, int64(3), dr.result.DeleteCnt) assert.Equal(t, int64(3), dr.result.DeleteCnt)
}) })
@ -1088,8 +1058,6 @@ func TestDeleteRunner_Run(t *testing.T) {
}, },
plan: plan, plan: plan,
} }
stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil) mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error { lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "") return workload.exec(ctx, 1, qn, "")
@ -1114,7 +1082,6 @@ func TestDeleteRunner_Run(t *testing.T) {
return client return client
}, nil) }, nil)
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
assert.NoError(t, dr.Run(ctx)) assert.NoError(t, dr.Run(ctx))
assert.Equal(t, int64(3), dr.result.DeleteCnt) assert.Equal(t, int64(3), dr.result.DeleteCnt)
}) })
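
The hunks above swap the per-test MsgStream mocks for the shared streaming test helpers. A minimal sketch of the resulting pattern, assuming the proxy package's existing test fixtures and only the helper name visible in this diff (streaming.ExpectErrorOnce), not its exact signature or semantics:

// Hypothetical test fragment, not part of this commit; deleteTask fields are
// elided and the surrounding fixtures are assumed to exist in package proxy.
func TestDeleteTask_ProduceError_Sketch(t *testing.T) {
    // Old pattern: mockMgr.EXPECT().getOrCreateDmlStream(...).Return(stream, nil)
    //              stream.EXPECT().Produce(...).Return(errors.New("mock error"))
    // New pattern (assumed behavior): queue a one-shot failure on the test WAL,
    // so the next append performed by a task returns this error.
    streaming.ExpectErrorOnce(errors.New("mock error"))

    dt := deleteTask{ /* req, primaryKeys, vChannels, idAllocator as in the tests above */ }
    assert.Error(t, dt.Execute(context.Background()))
}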

View File

@ -18,17 +18,11 @@ package proxy
import ( import (
"context" "context"
"fmt"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil" "github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/paramtable"
) )
@ -40,7 +34,7 @@ type flushTask struct {
mixCoord types.MixCoordClient mixCoord types.MixCoordClient
result *milvuspb.FlushResponse result *milvuspb.FlushResponse
replicateMsgStream msgstream.MsgStream chMgr channelsMgr
} }
func (t *flushTask) TraceCtx() context.Context { func (t *flushTask) TraceCtx() context.Context {
@ -88,46 +82,6 @@ func (t *flushTask) PreExecute(ctx context.Context) error {
return nil return nil
} }
func (t *flushTask) Execute(ctx context.Context) error {
coll2Segments := make(map[string]*schemapb.LongArray)
flushColl2Segments := make(map[string]*schemapb.LongArray)
coll2SealTimes := make(map[string]int64)
coll2FlushTs := make(map[string]Timestamp)
channelCps := make(map[string]*msgpb.MsgPosition)
for _, collName := range t.CollectionNames {
collID, err := globalMetaCache.GetCollectionID(ctx, t.GetDbName(), collName)
if err != nil {
return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound)
}
flushReq := &datapb.FlushRequest{
Base: commonpbutil.UpdateMsgBase(
t.Base,
commonpbutil.WithMsgType(commonpb.MsgType_Flush),
),
CollectionID: collID,
}
resp, err := t.mixCoord.Flush(ctx, flushReq)
if err = merr.CheckRPCCall(resp, err); err != nil {
return fmt.Errorf("failed to call flush to data coordinator: %s", err.Error())
}
coll2Segments[collName] = &schemapb.LongArray{Data: resp.GetSegmentIDs()}
flushColl2Segments[collName] = &schemapb.LongArray{Data: resp.GetFlushSegmentIDs()}
coll2SealTimes[collName] = resp.GetTimeOfSeal()
coll2FlushTs[collName] = resp.GetFlushTs()
channelCps = resp.GetChannelCps()
}
t.result = &milvuspb.FlushResponse{
Status: merr.Success(),
DbName: t.GetDbName(),
CollSegIDs: coll2Segments,
FlushCollSegIDs: flushColl2Segments,
CollSealTimes: coll2SealTimes,
CollFlushTs: coll2FlushTs,
ChannelCps: channelCps,
}
return nil
}
func (t *flushTask) PostExecute(ctx context.Context) error { func (t *flushTask) PostExecute(ctx context.Context) error {
return nil return nil
} }

View File

@ -18,15 +18,12 @@ package proxy
import ( import (
"context" "context"
"fmt"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb" "github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil" "github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/paramtable"
) )
@ -37,8 +34,7 @@ type flushAllTask struct {
ctx context.Context ctx context.Context
mixCoord types.MixCoordClient mixCoord types.MixCoordClient
result *datapb.FlushAllResponse result *datapb.FlushAllResponse
chMgr channelsMgr
replicateMsgStream msgstream.MsgStream
} }
func (t *flushAllTask) TraceCtx() context.Context { func (t *flushAllTask) TraceCtx() context.Context {
@ -86,22 +82,6 @@ func (t *flushAllTask) PreExecute(ctx context.Context) error {
return nil return nil
} }
func (t *flushAllTask) Execute(ctx context.Context) error {
flushAllReq := &datapb.FlushAllRequest{
Base: commonpbutil.UpdateMsgBase(
t.Base,
commonpbutil.WithMsgType(commonpb.MsgType_Flush),
),
}
resp, err := t.mixCoord.FlushAll(ctx, flushAllReq)
if err = merr.CheckRPCCall(resp, err); err != nil {
return fmt.Errorf("failed to call flush all to data coordinator: %s", err.Error())
}
t.result = resp
return nil
}
func (t *flushAllTask) PostExecute(ctx context.Context) error { func (t *flushAllTask) PostExecute(ctx context.Context) error {
return nil return nil
} }

View File

@ -30,12 +30,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/merr"
) )
type flushAllTaskbyStreamingService struct { func (t *flushAllTask) Execute(ctx context.Context) error {
*flushAllTask
chMgr channelsMgr
}
func (t *flushAllTaskbyStreamingService) Execute(ctx context.Context) error {
dbNames := make([]string, 0) dbNames := make([]string, 0)
if t.GetDbName() != "" { if t.GetDbName() != "" {
dbNames = append(dbNames, t.GetDbName()) dbNames = append(dbNames, t.GetDbName())
@ -60,7 +55,7 @@ func (t *flushAllTaskbyStreamingService) Execute(ctx context.Context) error {
Base: commonpbutil.NewMsgBase(commonpbutil.WithMsgType(commonpb.MsgType_ShowCollections)), Base: commonpbutil.NewMsgBase(commonpbutil.WithMsgType(commonpb.MsgType_ShowCollections)),
DbName: dbName, DbName: dbName,
}) })
if err != nil { if err := merr.CheckRPCCall(showColRsp, err); err != nil {
log.Info("flush all task by streaming service failed, show collections failed", zap.String("dbName", dbName), zap.Error(err)) log.Info("flush all task by streaming service failed, show collections failed", zap.String("dbName", dbName), zap.Error(err))
return err return err
} }
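
Besides merging the flush-all path, the hunk above tightens error handling: the old code only checked the transport error from the ShowCollections call, while merr.CheckRPCCall also rejects a reply whose embedded status carries an error code, matching how the other task Execute methods in this change validate coordinator responses. A minimal fragment of the pattern (logging elided; showColRsp is the ShowCollections response built just above this hunk):

if err := merr.CheckRPCCall(showColRsp, err); err != nil {
    return err // non-nil for both a failed call and an error status inside showColRsp
}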

View File

@ -33,7 +33,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/merr"
) )
func createTestFlushAllTaskByStreamingService(t *testing.T, dbName string) (*flushAllTaskbyStreamingService, *mocks.MockMixCoordClient, *msgstream.MockMsgStream, *MockChannelsMgr, context.Context) { func createTestFlushAllTaskByStreamingService(t *testing.T, dbName string) (*flushAllTask, *mocks.MockMixCoordClient, *msgstream.MockMsgStream, *MockChannelsMgr, context.Context) {
ctx := context.Background() ctx := context.Background()
mixCoord := mocks.NewMockMixCoordClient(t) mixCoord := mocks.NewMockMixCoordClient(t)
replicateMsgStream := msgstream.NewMockMsgStream(t) replicateMsgStream := msgstream.NewMockMsgStream(t)
@ -51,17 +51,11 @@ func createTestFlushAllTaskByStreamingService(t *testing.T, dbName string) (*flu
}, },
DbName: dbName, DbName: dbName,
}, },
ctx: ctx, ctx: ctx,
mixCoord: mixCoord, mixCoord: mixCoord,
replicateMsgStream: replicateMsgStream, chMgr: chMgr,
} }
return baseTask, mixCoord, replicateMsgStream, chMgr, ctx
task := &flushAllTaskbyStreamingService{
flushAllTask: baseTask,
chMgr: chMgr,
}
return task, mixCoord, replicateMsgStream, chMgr, ctx
} }
func TestFlushAllTask_WithSpecificDB(t *testing.T) { func TestFlushAllTask_WithSpecificDB(t *testing.T) {

View File

@ -18,19 +18,15 @@ package proxy
import ( import (
"context" "context"
"fmt"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream" "github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/uniquegenerator" "github.com/milvus-io/milvus/pkg/v2/util/uniquegenerator"
) )
@ -50,9 +46,8 @@ func createTestFlushAllTask(t *testing.T) (*flushAllTask, *mocks.MockMixCoordCli
SourceID: 1, SourceID: 1,
}, },
}, },
ctx: ctx, ctx: ctx,
mixCoord: mixCoord, mixCoord: mixCoord,
replicateMsgStream: replicateMsgStream,
} }
return task, mixCoord, replicateMsgStream, ctx return task, mixCoord, replicateMsgStream, ctx
@ -151,95 +146,6 @@ func TestFlushAllTaskPreExecute(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
func TestFlushAllTaskExecuteSuccess(t *testing.T) {
task, mixCoord, replicateMsgStream, ctx := createTestFlushAllTask(t)
defer mixCoord.AssertExpectations(t)
defer replicateMsgStream.AssertExpectations(t)
// Setup expectations
expectedResp := &datapb.FlushAllResponse{
Status: merr.Success(),
}
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
Return(expectedResp, nil).Once()
err := task.Execute(ctx)
assert.NoError(t, err)
assert.Equal(t, expectedResp, task.result)
}
func TestFlushAllTaskExecuteFlushAllRPCError(t *testing.T) {
task, mixCoord, replicateMsgStream, ctx := createTestFlushAllTask(t)
defer mixCoord.AssertExpectations(t)
defer replicateMsgStream.AssertExpectations(t)
// Test RPC call error
expectedErr := fmt.Errorf("rpc error")
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
Return(nil, expectedErr).Once()
err := task.Execute(ctx)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to call flush all to data coordinator")
}
func TestFlushAllTaskExecuteFlushAllResponseError(t *testing.T) {
task, mixCoord, replicateMsgStream, ctx := createTestFlushAllTask(t)
defer mixCoord.AssertExpectations(t)
defer replicateMsgStream.AssertExpectations(t)
// Test response with error status
errorResp := &datapb.FlushAllResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "flush all failed",
},
}
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
Return(errorResp, nil).Once()
err := task.Execute(ctx)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to call flush all to data coordinator")
}
func TestFlushAllTaskExecuteWithMerCheck(t *testing.T) {
task, mixCoord, replicateMsgStream, ctx := createTestFlushAllTask(t)
defer mixCoord.AssertExpectations(t)
defer replicateMsgStream.AssertExpectations(t)
// Test successful execution with merr.CheckRPCCall
successResp := &datapb.FlushAllResponse{
Status: merr.Success(),
}
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
Return(successResp, nil).Once()
err := task.Execute(ctx)
assert.NoError(t, err)
assert.Equal(t, successResp, task.result)
}
func TestFlushAllTaskExecuteRequestContent(t *testing.T) {
task, mixCoord, replicateMsgStream, ctx := createTestFlushAllTask(t)
defer mixCoord.AssertExpectations(t)
defer replicateMsgStream.AssertExpectations(t)
// Test the content of the FlushAllRequest sent to mixCoord
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
Return(&datapb.FlushAllResponse{Status: merr.Success()}, nil).Once()
err := task.Execute(ctx)
assert.NoError(t, err)
// The test verifies that Execute method creates the correct request structure internally
// The actual request content validation is covered by other tests
}
func TestFlushAllTaskPostExecute(t *testing.T) { func TestFlushAllTaskPostExecute(t *testing.T) {
task, mixCoord, replicateMsgStream, ctx := createTestFlushAllTask(t) task, mixCoord, replicateMsgStream, ctx := createTestFlushAllTask(t)
defer mixCoord.AssertExpectations(t) defer mixCoord.AssertExpectations(t)
@ -249,97 +155,6 @@ func TestFlushAllTaskPostExecute(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
func TestFlushAllTaskLifecycle(t *testing.T) {
ctx := context.Background()
mixCoord := mocks.NewMockMixCoordClient(t)
replicateMsgStream := msgstream.NewMockMsgStream(t)
defer mixCoord.AssertExpectations(t)
defer replicateMsgStream.AssertExpectations(t)
// Test complete task lifecycle
// 1. OnEnqueue
task := &flushAllTask{
baseTask: baseTask{},
Condition: NewTaskCondition(ctx),
FlushAllRequest: &milvuspb.FlushAllRequest{},
ctx: ctx,
mixCoord: mixCoord,
replicateMsgStream: replicateMsgStream,
}
err := task.OnEnqueue()
assert.NoError(t, err)
// 2. PreExecute
err = task.PreExecute(ctx)
assert.NoError(t, err)
// 3. Execute
expectedResp := &datapb.FlushAllResponse{
Status: merr.Success(),
}
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
Return(expectedResp, nil).Once()
err = task.Execute(ctx)
assert.NoError(t, err)
// 4. PostExecute
err = task.PostExecute(ctx)
assert.NoError(t, err)
// Verify task state
assert.Equal(t, expectedResp, task.result)
}
func TestFlushAllTaskErrorHandlingInExecute(t *testing.T) {
// Test different error scenarios in Execute method
testCases := []struct {
name string
setupMock func(*mocks.MockMixCoordClient)
expectedError string
}{
{
name: "mixCoord FlushAll returns error",
setupMock: func(mixCoord *mocks.MockMixCoordClient) {
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
Return(nil, fmt.Errorf("network error")).Once()
},
expectedError: "failed to call flush all to data coordinator",
},
{
name: "mixCoord FlushAll returns error status",
setupMock: func(mixCoord *mocks.MockMixCoordClient) {
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
Return(&datapb.FlushAllResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_IllegalArgument,
Reason: "invalid request",
},
}, nil).Once()
},
expectedError: "failed to call flush all to data coordinator",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
task, mixCoord, replicateMsgStream, ctx := createTestFlushAllTask(t)
defer mixCoord.AssertExpectations(t)
defer replicateMsgStream.AssertExpectations(t)
tc.setupMock(mixCoord)
err := task.Execute(ctx)
assert.Error(t, err)
assert.Contains(t, err.Error(), tc.expectedError)
})
}
}
func TestFlushAllTaskImplementsTaskInterface(t *testing.T) { func TestFlushAllTaskImplementsTaskInterface(t *testing.T) {
// Verify that flushAllTask implements the task interface // Verify that flushAllTask implements the task interface
var _ task = (*flushAllTask)(nil) var _ task = (*flushAllTask)(nil)

View File

@ -35,12 +35,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil" "github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
) )
type flushTaskByStreamingService struct { func (t *flushTask) Execute(ctx context.Context) error {
*flushTask
chMgr channelsMgr
}
func (t *flushTaskByStreamingService) Execute(ctx context.Context) error {
coll2Segments := make(map[string]*schemapb.LongArray) coll2Segments := make(map[string]*schemapb.LongArray)
flushColl2Segments := make(map[string]*schemapb.LongArray) flushColl2Segments := make(map[string]*schemapb.LongArray)
coll2SealTimes := make(map[string]int64) coll2SealTimes := make(map[string]int64)

View File

@ -32,7 +32,6 @@ import (
"github.com/milvus-io/milvus/internal/util/vecindexmgr" "github.com/milvus-io/milvus/internal/util/vecindexmgr"
"github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/indexpb" "github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil" "github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil" "github.com/milvus-io/milvus/pkg/v2/util/funcutil"
@ -64,8 +63,6 @@ type createIndexTask struct {
mixCoord types.MixCoordClient mixCoord types.MixCoordClient
result *commonpb.Status result *commonpb.Status
replicateMsgStream msgstream.MsgStream
isAutoIndex bool isAutoIndex bool
newIndexParams []*commonpb.KeyValuePair newIndexParams []*commonpb.KeyValuePair
newTypeParams []*commonpb.KeyValuePair newTypeParams []*commonpb.KeyValuePair
@ -580,7 +577,6 @@ func (cit *createIndexTask) Execute(ctx context.Context) error {
if err = merr.CheckRPCCall(cit.result, err); err != nil { if err = merr.CheckRPCCall(cit.result, err); err != nil {
return err return err
} }
SendReplicateMessagePack(ctx, cit.replicateMsgStream, cit.req)
return nil return nil
} }
@ -596,8 +592,6 @@ type alterIndexTask struct {
mixCoord types.MixCoordClient mixCoord types.MixCoordClient
result *commonpb.Status result *commonpb.Status
replicateMsgStream msgstream.MsgStream
collectionID UniqueID collectionID UniqueID
} }
@ -711,7 +705,6 @@ func (t *alterIndexTask) Execute(ctx context.Context) error {
if err = merr.CheckRPCCall(t.result, err); err != nil { if err = merr.CheckRPCCall(t.result, err); err != nil {
return err return err
} }
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.req)
return nil return nil
} }
@ -978,8 +971,6 @@ type dropIndexTask struct {
result *commonpb.Status result *commonpb.Status
collectionID UniqueID collectionID UniqueID
replicateMsgStream msgstream.MsgStream
} }
func (dit *dropIndexTask) TraceCtx() context.Context { func (dit *dropIndexTask) TraceCtx() context.Context {
@ -1073,7 +1064,6 @@ func (dit *dropIndexTask) Execute(ctx context.Context) error {
ctxLog.Warn("drop index failed", zap.Error(err)) ctxLog.Warn("drop index failed", zap.Error(err))
return err return err
} }
SendReplicateMessagePack(ctx, dit.replicateMsgStream, dit.DropIndexRequest)
return nil return nil
} }

View File

@ -31,6 +31,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/json" "github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/internal/util/indexparamcheck"
@ -46,6 +47,7 @@ import (
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
paramtable.Init() paramtable.Init()
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
streaming.SetupNoopWALForTest()
code := m.Run() code := m.Run()
os.Exit(code) os.Exit(code)
} }
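
For reference, the test bootstrap after this hunk: since the proxy no longer constructs DML msgstreams, the suite installs a no-op WAL once in TestMain, and individual tests inject failures through helpers such as streaming.ExpectErrorOnce (see the delete-task tests above). A sketch using only names visible in this diff:

func TestMain(m *testing.M) {
    paramtable.Init()
    gin.SetMode(gin.TestMode)
    // Stand-in for the removed msgstream factory: tasks append to a no-op WAL.
    streaming.SetupNoopWALForTest()
    code := m.Run()
    os.Exit(code)
}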

View File

@ -2,7 +2,6 @@ package proxy
import ( import (
"context" "context"
"fmt"
"strconv" "strconv"
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
@ -15,7 +14,6 @@ import (
"github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics" "github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil" "github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/paramtable"
@ -32,9 +30,7 @@ type insertTask struct {
result *milvuspb.MutationResult result *milvuspb.MutationResult
idAllocator *allocator.IDAllocator idAllocator *allocator.IDAllocator
segIDAssigner *segIDAssigner
chMgr channelsMgr chMgr channelsMgr
chTicker channelsTimeTicker
vChannels []vChan vChannels []vChan
pChannels []pChan pChannels []pChan
schema *schemapb.CollectionSchema schema *schemapb.CollectionSchema
@ -292,76 +288,6 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
return nil return nil
} }
func (it *insertTask) Execute(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Insert-Execute")
defer sp.End()
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute insert %d", it.ID()))
collectionName := it.insertMsg.CollectionName
collID, err := globalMetaCache.GetCollectionID(it.ctx, it.insertMsg.GetDbName(), collectionName)
log := log.Ctx(ctx)
if err != nil {
log.Warn("fail to get collection id", zap.Error(err))
return err
}
it.insertMsg.CollectionID = collID
getCacheDur := tr.RecordSpan()
stream, err := it.chMgr.getOrCreateDmlStream(ctx, collID)
if err != nil {
return err
}
getMsgStreamDur := tr.RecordSpan()
channelNames, err := it.chMgr.getVChannels(collID)
if err != nil {
log.Warn("get vChannels failed", zap.Int64("collectionID", collID), zap.Error(err))
it.result.Status = merr.Status(err)
return err
}
log.Debug("send insert request to virtual channels",
zap.String("partition", it.insertMsg.GetPartitionName()),
zap.Int64("collectionID", collID),
zap.Strings("virtual_channels", channelNames),
zap.Int64("task_id", it.ID()),
zap.Duration("get cache duration", getCacheDur),
zap.Duration("get msgStream duration", getMsgStreamDur))
// assign segmentID for insert data and repack data by segmentID
var msgPack *msgstream.MsgPack
if it.partitionKeys == nil {
msgPack, err = repackInsertData(it.TraceCtx(), channelNames, it.insertMsg, it.result, it.idAllocator, it.segIDAssigner)
} else {
msgPack, err = repackInsertDataWithPartitionKey(it.TraceCtx(), channelNames, it.partitionKeys, it.insertMsg, it.result, it.idAllocator, it.segIDAssigner)
}
if err != nil {
log.Warn("assign segmentID and repack insert data failed", zap.Error(err))
it.result.Status = merr.Status(err)
return err
}
assignSegmentIDDur := tr.RecordSpan()
log.Debug("assign segmentID for insert data success",
zap.Duration("assign segmentID duration", assignSegmentIDDur))
err = stream.Produce(ctx, msgPack)
if err != nil {
log.Warn("fail to produce insert msg", zap.Error(err))
it.result.Status = merr.Status(err)
return err
}
sendMsgDur := tr.RecordSpan()
metrics.ProxySendMutationReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.InsertLabel).Observe(float64(sendMsgDur.Milliseconds()))
totalExecDur := tr.ElapseSpan()
log.Debug("Proxy Insert Execute done",
zap.String("collectionName", collectionName),
zap.Duration("send message duration", sendMsgDur),
zap.Duration("execute duration", totalExecDur))
return nil
}
func (it *insertTask) PostExecute(ctx context.Context) error { func (it *insertTask) PostExecute(ctx context.Context) error {
return nil return nil
} }

View File

@ -18,12 +18,8 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/typeutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
) )
type insertTaskByStreamingService struct {
*insertTask
}
// we only overwrite the Execute function // we only overwrite the Execute function
func (it *insertTaskByStreamingService) Execute(ctx context.Context) error { func (it *insertTask) Execute(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Insert-Execute") ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Insert-Execute")
defer sp.End() defer sp.End()

View File

@ -29,7 +29,6 @@ import (
"github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics" "github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/util/conc" "github.com/milvus-io/milvus/pkg/v2/util/conc"
"github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metricsinfo" "github.com/milvus-io/milvus/pkg/v2/util/metricsinfo"
@ -437,22 +436,18 @@ type taskScheduler struct {
wg sync.WaitGroup wg sync.WaitGroup
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
msFactory msgstream.Factory
} }
type schedOpt func(*taskScheduler) type schedOpt func(*taskScheduler)
func newTaskScheduler(ctx context.Context, func newTaskScheduler(ctx context.Context,
tsoAllocatorIns tsoAllocator, tsoAllocatorIns tsoAllocator,
factory msgstream.Factory,
opts ...schedOpt, opts ...schedOpt,
) (*taskScheduler, error) { ) (*taskScheduler, error) {
ctx1, cancel := context.WithCancel(ctx) ctx1, cancel := context.WithCancel(ctx)
s := &taskScheduler{ s := &taskScheduler{
ctx: ctx1, ctx: ctx1,
cancel: cancel, cancel: cancel,
msFactory: factory,
} }
s.ddQueue = newDdTaskQueue(tsoAllocatorIns) s.ddQueue = newDdTaskQueue(tsoAllocatorIns)
s.dmQueue = newDmTaskQueue(tsoAllocatorIns) s.dmQueue = newDmTaskQueue(tsoAllocatorIns)
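
With the msgstream dependency gone, the scheduler no longer needs a factory at construction time, and the call sites in the test hunks below shrink to two arguments. A sketch of the new shape, with the old form kept as a comment for contrast (Start/Close usage as in the tests above):

// before: sched, err := newTaskScheduler(ctx, tsoAllocator, factory)
sched, err := newTaskScheduler(ctx, tsoAllocator)
if err != nil {
    return err
}
sched.Start()
defer sched.Close()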

View File

@ -492,9 +492,8 @@ func TestTaskScheduler(t *testing.T) {
ctx := context.Background() ctx := context.Background()
tsoAllocatorIns := newMockTsoAllocator() tsoAllocatorIns := newMockTsoAllocator()
factory := newSimpleMockMsgStreamFactory()
sched, err := newTaskScheduler(ctx, tsoAllocatorIns, factory) sched, err := newTaskScheduler(ctx, tsoAllocatorIns)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, sched) assert.NotNil(t, sched)
@ -572,8 +571,7 @@ func TestTaskScheduler_concurrentPushAndPop(t *testing.T) {
).Return(collectionID, nil) ).Return(collectionID, nil)
globalMetaCache = cache globalMetaCache = cache
tsoAllocatorIns := newMockTsoAllocator() tsoAllocatorIns := newMockTsoAllocator()
factory := newSimpleMockMsgStreamFactory() scheduler, err := newTaskScheduler(context.Background(), tsoAllocatorIns)
scheduler, err := newTaskScheduler(context.Background(), tsoAllocatorIns, factory)
assert.NoError(t, err) assert.NoError(t, err)
run := func(wg *sync.WaitGroup) { run := func(wg *sync.WaitGroup) {

View File

@ -3586,7 +3586,7 @@ func TestSearchTask_Requery(t *testing.T) {
node.tsoAllocator = &timestampAllocator{ node.tsoAllocator = &timestampAllocator{
tso: newMockTimestampAllocatorInterface(), tso: newMockTimestampAllocatorInterface(),
} }
scheduler, err := newTaskScheduler(ctx, node.tsoAllocator, factory) scheduler, err := newTaskScheduler(ctx, node.tsoAllocator)
assert.NoError(t, err) assert.NoError(t, err)
node.sched = scheduler node.sched = scheduler
err = node.sched.Start() err = node.sched.Start()

View File

@ -2161,12 +2161,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
dmlChannelsFunc := getDmlChannelsFunc(ctx, qc) dmlChannelsFunc := getDmlChannelsFunc(ctx, qc)
factory := newSimpleMockMsgStreamFactory() chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil)
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
defer chMgr.removeAllDMLStream()
_, err = chMgr.getOrCreateDmlStream(ctx, collectionID)
assert.NoError(t, err)
pchans, err := chMgr.getChannels(collectionID) pchans, err := chMgr.getChannels(collectionID)
assert.NoError(t, err) assert.NoError(t, err)
@ -2182,11 +2177,6 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
_ = idAllocator.Start() _ = idAllocator.Start()
defer idAllocator.Close() defer idAllocator.Close()
segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
assert.NoError(t, err)
_ = segAllocator.Start()
defer segAllocator.Close()
t.Run("insert", func(t *testing.T) { t.Run("insert", func(t *testing.T) {
hash := testutils.GenerateHashKeys(nb) hash := testutils.GenerateHashKeys(nb)
task := &insertTask{ task := &insertTask{
@ -2221,13 +2211,11 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
UpsertCnt: 0, UpsertCnt: 0,
Timestamp: 0, Timestamp: 0,
}, },
idAllocator: idAllocator, idAllocator: idAllocator,
segIDAssigner: segAllocator, chMgr: chMgr,
chMgr: chMgr, vChannels: nil,
chTicker: ticker, pChannels: nil,
vChannels: nil, schema: nil,
pChannels: nil,
schema: nil,
} }
for fieldName, dataType := range fieldName2Types { for fieldName, dataType := range fieldName2Types {
@ -2403,12 +2391,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
dmlChannelsFunc := getDmlChannelsFunc(ctx, mixc) dmlChannelsFunc := getDmlChannelsFunc(ctx, mixc)
factory := newSimpleMockMsgStreamFactory() chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil)
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
defer chMgr.removeAllDMLStream()
_, err = chMgr.getOrCreateDmlStream(ctx, collectionID)
assert.NoError(t, err)
pchans, err := chMgr.getChannels(collectionID) pchans, err := chMgr.getChannels(collectionID)
assert.NoError(t, err) assert.NoError(t, err)
@ -2424,12 +2407,6 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
_ = idAllocator.Start() _ = idAllocator.Start()
defer idAllocator.Close() defer idAllocator.Close()
segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
assert.NoError(t, err)
segAllocator.Init()
_ = segAllocator.Start()
defer segAllocator.Close()
t.Run("insert", func(t *testing.T) { t.Run("insert", func(t *testing.T) {
hash := testutils.GenerateHashKeys(nb) hash := testutils.GenerateHashKeys(nb)
task := &insertTask{ task := &insertTask{
@ -2464,13 +2441,11 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
UpsertCnt: 0, UpsertCnt: 0,
Timestamp: 0, Timestamp: 0,
}, },
idAllocator: idAllocator, idAllocator: idAllocator,
segIDAssigner: segAllocator, chMgr: chMgr,
chMgr: chMgr, vChannels: nil,
chTicker: ticker, pChannels: nil,
vChannels: nil, schema: nil,
pChannels: nil,
schema: nil,
} }
fieldID := common.StartOfUserFieldID fieldID := common.StartOfUserFieldID
@ -2549,13 +2524,12 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
UpsertCnt: 0, UpsertCnt: 0,
Timestamp: 0, Timestamp: 0,
}, },
idAllocator: idAllocator, idAllocator: idAllocator,
segIDAssigner: segAllocator, chMgr: chMgr,
chMgr: chMgr, chTicker: ticker,
chTicker: ticker, vChannels: nil,
vChannels: nil, pChannels: nil,
pChannels: nil, schema: nil,
schema: nil,
} }
fieldID := common.StartOfUserFieldID fieldID := common.StartOfUserFieldID
@ -3817,12 +3791,7 @@ func TestPartitionKey(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
dmlChannelsFunc := getDmlChannelsFunc(ctx, qc) dmlChannelsFunc := getDmlChannelsFunc(ctx, qc)
factory := newSimpleMockMsgStreamFactory() chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil)
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
defer chMgr.removeAllDMLStream()
_, err = chMgr.getOrCreateDmlStream(ctx, collectionID)
assert.NoError(t, err)
pchans, err := chMgr.getChannels(collectionID) pchans, err := chMgr.getChannels(collectionID)
assert.NoError(t, err) assert.NoError(t, err)
@ -3838,12 +3807,6 @@ func TestPartitionKey(t *testing.T) {
_ = idAllocator.Start() _ = idAllocator.Start()
defer idAllocator.Close() defer idAllocator.Close()
segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
assert.NoError(t, err)
segAllocator.Init()
_ = segAllocator.Start()
defer segAllocator.Close()
partitionNames, err := getDefaultPartitionsInPartitionKeyMode(ctx, "", collectionName) partitionNames, err := getDefaultPartitionsInPartitionKeyMode(ctx, "", collectionName)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, common.DefaultPartitionsWithPartitionKey, int64(len(partitionNames))) assert.Equal(t, common.DefaultPartitionsWithPartitionKey, int64(len(partitionNames)))
@ -3888,13 +3851,11 @@ func TestPartitionKey(t *testing.T) {
UpsertCnt: 0, UpsertCnt: 0,
Timestamp: 0, Timestamp: 0,
}, },
idAllocator: idAllocator, idAllocator: idAllocator,
segIDAssigner: segAllocator, chMgr: chMgr,
chMgr: chMgr, vChannels: nil,
chTicker: ticker, pChannels: nil,
vChannels: nil, schema: nil,
pChannels: nil,
schema: nil,
} }
// don't support specify partition name if use partition key // don't support specify partition name if use partition key
@ -3932,10 +3893,9 @@ func TestPartitionKey(t *testing.T) {
IdField: nil, IdField: nil,
}, },
}, },
idAllocator: idAllocator, idAllocator: idAllocator,
segIDAssigner: segAllocator, chMgr: chMgr,
chMgr: chMgr, chTicker: ticker,
chTicker: ticker,
} }
// don't support specify partition name if use partition key // don't support specify partition name if use partition key
@ -4060,12 +4020,8 @@ func TestDefaultPartition(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
dmlChannelsFunc := getDmlChannelsFunc(ctx, qc) dmlChannelsFunc := getDmlChannelsFunc(ctx, qc)
factory := newSimpleMockMsgStreamFactory() chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil)
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
defer chMgr.removeAllDMLStream()
_, err = chMgr.getOrCreateDmlStream(ctx, collectionID)
assert.NoError(t, err)
pchans, err := chMgr.getChannels(collectionID) pchans, err := chMgr.getChannels(collectionID)
assert.NoError(t, err) assert.NoError(t, err)
@ -4081,12 +4037,6 @@ func TestDefaultPartition(t *testing.T) {
_ = idAllocator.Start() _ = idAllocator.Start()
defer idAllocator.Close() defer idAllocator.Close()
segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
assert.NoError(t, err)
segAllocator.Init()
_ = segAllocator.Start()
defer segAllocator.Close()
nb := 10 nb := 10
fieldID := common.StartOfUserFieldID fieldID := common.StartOfUserFieldID
fieldDatas := make([]*schemapb.FieldData, 0) fieldDatas := make([]*schemapb.FieldData, 0)
@ -4127,13 +4077,11 @@ func TestDefaultPartition(t *testing.T) {
UpsertCnt: 0, UpsertCnt: 0,
Timestamp: 0, Timestamp: 0,
}, },
idAllocator: idAllocator, idAllocator: idAllocator,
segIDAssigner: segAllocator, chMgr: chMgr,
chMgr: chMgr, vChannels: nil,
chTicker: ticker, pChannels: nil,
vChannels: nil, schema: nil,
pChannels: nil,
schema: nil,
} }
it.insertMsg.PartitionName = "" it.insertMsg.PartitionName = ""
@ -4167,10 +4115,9 @@ func TestDefaultPartition(t *testing.T) {
IdField: nil, IdField: nil,
}, },
}, },
idAllocator: idAllocator, idAllocator: idAllocator,
segIDAssigner: segAllocator, chMgr: chMgr,
chMgr: chMgr, chTicker: ticker,
chTicker: ticker,
} }
ut.req.PartitionName = "" ut.req.PartitionName = ""

View File

@ -17,7 +17,6 @@ package proxy
import ( import (
"context" "context"
"fmt"
"strconv" "strconv"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
@ -55,7 +54,6 @@ type upsertTask struct {
rowIDs []int64 rowIDs []int64
result *milvuspb.MutationResult result *milvuspb.MutationResult
idAllocator *allocator.IDAllocator idAllocator *allocator.IDAllocator
segIDAssigner *segIDAssigner
collectionID UniqueID collectionID UniqueID
chMgr channelsMgr chMgr channelsMgr
chTicker channelsTimeTicker chTicker channelsTimeTicker
@ -445,189 +443,6 @@ func (it *upsertTask) PreExecute(ctx context.Context) error {
return nil return nil
} }
func (it *upsertTask) insertExecute(ctx context.Context, msgPack *msgstream.MsgPack) error {
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy insertExecute upsert %d", it.ID()))
defer tr.Elapse("insert execute done when insertExecute")
collectionName := it.upsertMsg.InsertMsg.CollectionName
collID, err := globalMetaCache.GetCollectionID(ctx, it.req.GetDbName(), collectionName)
if err != nil {
return err
}
it.upsertMsg.InsertMsg.CollectionID = collID
it.upsertMsg.InsertMsg.BeginTimestamp = it.BeginTs()
it.upsertMsg.InsertMsg.EndTimestamp = it.EndTs()
log := log.Ctx(ctx).With(
zap.Int64("collectionID", collID))
getCacheDur := tr.RecordSpan()
_, err = it.chMgr.getOrCreateDmlStream(ctx, collID)
if err != nil {
return err
}
getMsgStreamDur := tr.RecordSpan()
channelNames, err := it.chMgr.getVChannels(collID)
if err != nil {
log.Warn("get vChannels failed when insertExecute",
zap.Error(err))
it.result.Status = merr.Status(err)
return err
}
log.Debug("send insert request to virtual channels when insertExecute",
zap.String("collection", it.req.GetCollectionName()),
zap.String("partition", it.req.GetPartitionName()),
zap.Int64("collection_id", collID),
zap.Strings("virtual_channels", channelNames),
zap.Int64("task_id", it.ID()),
zap.Duration("get cache duration", getCacheDur),
zap.Duration("get msgStream duration", getMsgStreamDur))
// assign segmentID for insert data and repack data by segmentID
var insertMsgPack *msgstream.MsgPack
if it.partitionKeys == nil {
insertMsgPack, err = repackInsertData(it.TraceCtx(), channelNames, it.upsertMsg.InsertMsg, it.result, it.idAllocator, it.segIDAssigner)
} else {
insertMsgPack, err = repackInsertDataWithPartitionKey(it.TraceCtx(), channelNames, it.partitionKeys, it.upsertMsg.InsertMsg, it.result, it.idAllocator, it.segIDAssigner)
}
if err != nil {
log.Warn("assign segmentID and repack insert data failed when insertExecute",
zap.Error(err))
it.result.Status = merr.Status(err)
return err
}
assignSegmentIDDur := tr.RecordSpan()
log.Debug("assign segmentID for insert data success when insertExecute",
zap.String("collectionName", it.req.CollectionName),
zap.Duration("assign segmentID duration", assignSegmentIDDur))
msgPack.Msgs = append(msgPack.Msgs, insertMsgPack.Msgs...)
log.Debug("Proxy Insert Execute done when upsert",
zap.String("collectionName", collectionName))
return nil
}
func (it *upsertTask) deleteExecute(ctx context.Context, msgPack *msgstream.MsgPack) (err error) {
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy deleteExecute upsert %d", it.ID()))
collID := it.upsertMsg.DeleteMsg.CollectionID
log := log.Ctx(ctx).With(
zap.Int64("collectionID", collID))
// hash primary keys to channels
channelNames, err := it.chMgr.getVChannels(collID)
if err != nil {
log.Warn("get vChannels failed when deleteExecute", zap.Error(err))
it.result.Status = merr.Status(err)
return err
}
it.upsertMsg.DeleteMsg.PrimaryKeys = it.oldIDs
it.upsertMsg.DeleteMsg.HashValues = typeutil.HashPK2Channels(it.upsertMsg.DeleteMsg.PrimaryKeys, channelNames)
// repack delete msg by dmChannel
result := make(map[uint32]msgstream.TsMsg)
collectionName := it.upsertMsg.DeleteMsg.CollectionName
collectionID := it.upsertMsg.DeleteMsg.CollectionID
partitionID := it.upsertMsg.DeleteMsg.PartitionID
partitionName := it.upsertMsg.DeleteMsg.PartitionName
proxyID := it.upsertMsg.DeleteMsg.Base.SourceID
for index, key := range it.upsertMsg.DeleteMsg.HashValues {
ts := it.upsertMsg.DeleteMsg.Timestamps[index]
_, ok := result[key]
if !ok {
msgid, err := it.idAllocator.AllocOne()
if err != nil {
return errors.Wrap(err, "failed to allocate MsgID for delete of upsert")
}
sliceRequest := &msgpb.DeleteRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Delete),
commonpbutil.WithTimeStamp(ts),
// id of upsertTask were set as ts in scheduler
// msgid of delete msg must be set
// or it will be seen as duplicated msg in mq
commonpbutil.WithMsgID(msgid),
commonpbutil.WithSourceID(proxyID),
),
CollectionID: collectionID,
PartitionID: partitionID,
CollectionName: collectionName,
PartitionName: partitionName,
PrimaryKeys: &schemapb.IDs{},
}
deleteMsg := &msgstream.DeleteMsg{
BaseMsg: msgstream.BaseMsg{
Ctx: ctx,
},
DeleteRequest: sliceRequest,
}
result[key] = deleteMsg
}
curMsg := result[key].(*msgstream.DeleteMsg)
curMsg.HashValues = append(curMsg.HashValues, it.upsertMsg.DeleteMsg.HashValues[index])
curMsg.Timestamps = append(curMsg.Timestamps, it.upsertMsg.DeleteMsg.Timestamps[index])
typeutil.AppendIDs(curMsg.PrimaryKeys, it.upsertMsg.DeleteMsg.PrimaryKeys, index)
curMsg.NumRows++
curMsg.ShardName = channelNames[key]
}
// send delete request to log broker
deleteMsgPack := &msgstream.MsgPack{
BeginTs: it.upsertMsg.DeleteMsg.BeginTs(),
EndTs: it.upsertMsg.DeleteMsg.EndTs(),
}
for _, msg := range result {
if msg != nil {
deleteMsgPack.Msgs = append(deleteMsgPack.Msgs, msg)
}
}
msgPack.Msgs = append(msgPack.Msgs, deleteMsgPack.Msgs...)
log.Debug("Proxy Upsert deleteExecute done", zap.Int64("collection_id", collID),
zap.Strings("virtual_channels", channelNames), zap.Int64("taskID", it.ID()),
zap.Duration("prepare duration", tr.ElapseSpan()))
return nil
}
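
For reference, the primary-key routing used above works like this: typeutil.HashPK2Channels returns one vchannel index per primary key, so deletes for the same key always hash to the same channel. A minimal standalone sketch (the channel names are made up for illustration):

    package main

    import (
        "fmt"

        "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
        "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
    )

    func main() {
        // Three int64 primary keys, wrapped the way the proxy passes them around.
        pks := &schemapb.IDs{
            IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{1, 2, 3}}},
        }
        channels := []string{"by-dev-dml_0_v0", "by-dev-dml_1_v0"}
        // One channel index per primary key.
        for i, idx := range typeutil.HashPK2Channels(pks, channels) {
            fmt.Printf("pk #%d -> %s\n", i, channels[idx])
        }
    }
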
func (it *upsertTask) Execute(ctx context.Context) (err error) {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-Execute")
defer sp.End()
log := log.Ctx(ctx).With(zap.String("collectionName", it.req.CollectionName))
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute upsert %d", it.ID()))
stream, err := it.chMgr.getOrCreateDmlStream(ctx, it.collectionID)
if err != nil {
return err
}
msgPack := &msgstream.MsgPack{
BeginTs: it.BeginTs(),
EndTs: it.EndTs(),
}
err = it.insertExecute(ctx, msgPack)
if err != nil {
log.Warn("Fail to insertExecute", zap.Error(err))
return err
}
err = it.deleteExecute(ctx, msgPack)
if err != nil {
log.Warn("Fail to deleteExecute", zap.Error(err))
return err
}
tr.RecordSpan()
err = stream.Produce(ctx, msgPack)
if err != nil {
it.result.Status = merr.Status(err)
return err
}
sendMsgDur := tr.RecordSpan()
metrics.ProxySendMutationReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.UpsertLabel).Observe(float64(sendMsgDur.Milliseconds()))
totalDur := tr.ElapseSpan()
log.Debug("Proxy Upsert Execute done", zap.Int64("taskID", it.ID()),
zap.Duration("total duration", totalDur))
return nil
}
func (it *upsertTask) PostExecute(ctx context.Context) error { func (it *upsertTask) PostExecute(ctx context.Context) error {
return nil return nil
} }

View File

@ -15,11 +15,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/typeutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
) )
type upsertTaskByStreamingService struct { func (ut *upsertTask) Execute(ctx context.Context) error {
*upsertTask
}
func (ut *upsertTaskByStreamingService) Execute(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-Execute") ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-Execute")
defer sp.End() defer sp.End()
log := log.Ctx(ctx).With(zap.String("collectionName", ut.req.CollectionName)) log := log.Ctx(ctx).With(zap.String("collectionName", ut.req.CollectionName))
@ -46,7 +42,7 @@ func (ut *upsertTaskByStreamingService) Execute(ctx context.Context) error {
return nil return nil
} }
func (ut *upsertTaskByStreamingService) packInsertMessage(ctx context.Context) ([]message.MutableMessage, error) { func (ut *upsertTask) packInsertMessage(ctx context.Context) ([]message.MutableMessage, error) {
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy insertExecute upsert %d", ut.ID())) tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy insertExecute upsert %d", ut.ID()))
defer tr.Elapse("insert execute done when insertExecute") defer tr.Elapse("insert execute done when insertExecute")
@ -93,7 +89,7 @@ func (ut *upsertTaskByStreamingService) packInsertMessage(ctx context.Context) (
return msgs, nil return msgs, nil
} }
func (it *upsertTaskByStreamingService) packDeleteMessage(ctx context.Context) ([]message.MutableMessage, error) { func (it *upsertTask) packDeleteMessage(ctx context.Context) ([]message.MutableMessage, error) {
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy deleteExecute upsert %d", it.ID())) tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy deleteExecute upsert %d", it.ID()))
collID := it.upsertMsg.DeleteMsg.CollectionID collID := it.upsertMsg.DeleteMsg.CollectionID
it.upsertMsg.DeleteMsg.PrimaryKeys = it.oldIDs it.upsertMsg.DeleteMsg.PrimaryKeys = it.oldIDs
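
With the upsertTaskByStreamingService wrapper folded into upsertTask, Execute presumably packs the insert and delete messages shown above and hands them to the streaming WAL in one call. A rough sketch under that assumption (it reuses the packInsertMessage/packDeleteMessage helpers and the AppendOperator contract added later in this commit, assumes the surrounding proxy package's imports, and trims error handling):

    // Rough sketch only; the real Execute body is not fully shown in this diff.
    func (ut *upsertTask) executeSketch(ctx context.Context) error {
        insertMsgs, err := ut.packInsertMessage(ctx)
        if err != nil {
            return err
        }
        deleteMsgs, err := ut.packDeleteMessage(ctx)
        if err != nil {
            return err
        }
        // streaming.WAL() satisfies the AppendOperator interface, so one append
        // replaces the old per-stream Produce call.
        resp := streaming.WAL().AppendMessages(ctx, append(insertMsgs, deleteMsgs...)...)
        if err := resp.UnwrapFirstError(); err != nil {
            ut.result.Status = merr.Status(err)
            return err
        }
        return nil
    }
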

View File

@ -2286,184 +2286,6 @@ func checkDynamicFieldData(schema *schemapb.CollectionSchema, insertMsg *msgstre
return nil return nil
} }
func SendReplicateMessagePack(ctx context.Context, replicateMsgStream msgstream.MsgStream, request interface{ GetBase() *commonpb.MsgBase }) {
if replicateMsgStream == nil || request == nil {
log.Ctx(ctx).Warn("replicate msg stream or request is nil", zap.Any("request", request))
return
}
msgBase := request.GetBase()
ts := msgBase.GetTimestamp()
if msgBase.GetReplicateInfo().GetIsReplicate() {
ts = msgBase.GetReplicateInfo().GetMsgTimestamp()
}
getBaseMsg := func(ctx context.Context, ts uint64) msgstream.BaseMsg {
return msgstream.BaseMsg{
Ctx: ctx,
HashValues: []uint32{0},
BeginTimestamp: ts,
EndTimestamp: ts,
}
}
var tsMsg msgstream.TsMsg
switch r := request.(type) {
case *milvuspb.AlterCollectionRequest:
tsMsg = &msgstream.AlterCollectionMsg{
BaseMsg: getBaseMsg(ctx, ts),
AlterCollectionRequest: r,
}
case *milvuspb.AlterCollectionFieldRequest:
tsMsg = &msgstream.AlterCollectionFieldMsg{
BaseMsg: getBaseMsg(ctx, ts),
AlterCollectionFieldRequest: r,
}
case *milvuspb.RenameCollectionRequest:
tsMsg = &msgstream.RenameCollectionMsg{
BaseMsg: getBaseMsg(ctx, ts),
RenameCollectionRequest: r,
}
case *milvuspb.CreateDatabaseRequest:
tsMsg = &msgstream.CreateDatabaseMsg{
BaseMsg: getBaseMsg(ctx, ts),
CreateDatabaseRequest: r,
}
case *milvuspb.DropDatabaseRequest:
tsMsg = &msgstream.DropDatabaseMsg{
BaseMsg: getBaseMsg(ctx, ts),
DropDatabaseRequest: r,
}
case *milvuspb.AlterDatabaseRequest:
tsMsg = &msgstream.AlterDatabaseMsg{
BaseMsg: getBaseMsg(ctx, ts),
AlterDatabaseRequest: r,
}
case *milvuspb.FlushRequest:
tsMsg = &msgstream.FlushMsg{
BaseMsg: getBaseMsg(ctx, ts),
FlushRequest: r,
}
case *milvuspb.LoadCollectionRequest:
tsMsg = &msgstream.LoadCollectionMsg{
BaseMsg: getBaseMsg(ctx, ts),
LoadCollectionRequest: r,
}
case *milvuspb.ReleaseCollectionRequest:
tsMsg = &msgstream.ReleaseCollectionMsg{
BaseMsg: getBaseMsg(ctx, ts),
ReleaseCollectionRequest: r,
}
case *milvuspb.CreateIndexRequest:
tsMsg = &msgstream.CreateIndexMsg{
BaseMsg: getBaseMsg(ctx, ts),
CreateIndexRequest: r,
}
case *milvuspb.DropIndexRequest:
tsMsg = &msgstream.DropIndexMsg{
BaseMsg: getBaseMsg(ctx, ts),
DropIndexRequest: r,
}
case *milvuspb.LoadPartitionsRequest:
tsMsg = &msgstream.LoadPartitionsMsg{
BaseMsg: getBaseMsg(ctx, ts),
LoadPartitionsRequest: r,
}
case *milvuspb.ReleasePartitionsRequest:
tsMsg = &msgstream.ReleasePartitionsMsg{
BaseMsg: getBaseMsg(ctx, ts),
ReleasePartitionsRequest: r,
}
case *milvuspb.AlterIndexRequest:
tsMsg = &msgstream.AlterIndexMsg{
BaseMsg: getBaseMsg(ctx, ts),
AlterIndexRequest: r,
}
case *milvuspb.CreateCredentialRequest:
tsMsg = &msgstream.CreateUserMsg{
BaseMsg: getBaseMsg(ctx, ts),
CreateCredentialRequest: r,
}
case *milvuspb.UpdateCredentialRequest:
tsMsg = &msgstream.UpdateUserMsg{
BaseMsg: getBaseMsg(ctx, ts),
UpdateCredentialRequest: r,
}
case *milvuspb.DeleteCredentialRequest:
tsMsg = &msgstream.DeleteUserMsg{
BaseMsg: getBaseMsg(ctx, ts),
DeleteCredentialRequest: r,
}
case *milvuspb.CreateRoleRequest:
tsMsg = &msgstream.CreateRoleMsg{
BaseMsg: getBaseMsg(ctx, ts),
CreateRoleRequest: r,
}
case *milvuspb.DropRoleRequest:
tsMsg = &msgstream.DropRoleMsg{
BaseMsg: getBaseMsg(ctx, ts),
DropRoleRequest: r,
}
case *milvuspb.OperateUserRoleRequest:
tsMsg = &msgstream.OperateUserRoleMsg{
BaseMsg: getBaseMsg(ctx, ts),
OperateUserRoleRequest: r,
}
case *milvuspb.OperatePrivilegeRequest:
tsMsg = &msgstream.OperatePrivilegeMsg{
BaseMsg: getBaseMsg(ctx, ts),
OperatePrivilegeRequest: r,
}
case *milvuspb.OperatePrivilegeV2Request:
tsMsg = &msgstream.OperatePrivilegeV2Msg{
BaseMsg: getBaseMsg(ctx, ts),
OperatePrivilegeV2Request: r,
}
case *milvuspb.CreatePrivilegeGroupRequest:
tsMsg = &msgstream.CreatePrivilegeGroupMsg{
BaseMsg: getBaseMsg(ctx, ts),
CreatePrivilegeGroupRequest: r,
}
case *milvuspb.DropPrivilegeGroupRequest:
tsMsg = &msgstream.DropPrivilegeGroupMsg{
BaseMsg: getBaseMsg(ctx, ts),
DropPrivilegeGroupRequest: r,
}
case *milvuspb.OperatePrivilegeGroupRequest:
tsMsg = &msgstream.OperatePrivilegeGroupMsg{
BaseMsg: getBaseMsg(ctx, ts),
OperatePrivilegeGroupRequest: r,
}
case *milvuspb.CreateAliasRequest:
tsMsg = &msgstream.CreateAliasMsg{
BaseMsg: getBaseMsg(ctx, ts),
CreateAliasRequest: r,
}
case *milvuspb.DropAliasRequest:
tsMsg = &msgstream.DropAliasMsg{
BaseMsg: getBaseMsg(ctx, ts),
DropAliasRequest: r,
}
case *milvuspb.AlterAliasRequest:
tsMsg = &msgstream.AlterAliasMsg{
BaseMsg: getBaseMsg(ctx, ts),
AlterAliasRequest: r,
}
default:
log.Warn("unknown request", zap.Any("request", request))
return
}
msgPack := &msgstream.MsgPack{
BeginTs: ts,
EndTs: ts,
Msgs: []msgstream.TsMsg{tsMsg},
}
msgErr := replicateMsgStream.Produce(ctx, msgPack)
// ignore the error if the msg stream failed to produce the msg,
// because it can be manually fixed in this error
if msgErr != nil {
log.Warn("send replicate msg failed", zap.Any("pack", msgPack), zap.Error(msgErr))
}
}
func GetCachedCollectionSchema(ctx context.Context, dbName string, colName string) (*schemaInfo, error) { func GetCachedCollectionSchema(ctx context.Context, dbName string, colName string) (*schemaInfo, error) {
if globalMetaCache != nil { if globalMetaCache != nil {
return globalMetaCache.GetCollectionSchema(ctx, dbName, colName) return globalMetaCache.GetCollectionSchema(ctx, dbName, colName)

View File

@ -2302,55 +2302,6 @@ func Test_validateMaxCapacityPerRow(t *testing.T) {
}) })
} }
func TestSendReplicateMessagePack(t *testing.T) {
ctx := context.Background()
mockStream := msgstream.NewMockMsgStream(t)
t.Run("empty case", func(t *testing.T) {
SendReplicateMessagePack(ctx, nil, nil)
})
t.Run("produce fail", func(t *testing.T) {
mockStream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("produce error")).Once()
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateDatabaseRequest{
Base: &commonpb.MsgBase{ReplicateInfo: &commonpb.ReplicateInfo{
IsReplicate: true,
MsgTimestamp: 100,
}},
})
})
t.Run("unknown request", func(t *testing.T) {
SendReplicateMessagePack(ctx, mockStream, &milvuspb.ListDatabasesRequest{})
})
t.Run("normal case", func(t *testing.T) {
mockStream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
SendReplicateMessagePack(ctx, mockStream, &milvuspb.AlterCollectionRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.AlterCollectionFieldRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.RenameCollectionRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateDatabaseRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.DropDatabaseRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.FlushRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.LoadCollectionRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.ReleaseCollectionRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateIndexRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.DropIndexRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.LoadPartitionsRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.ReleasePartitionsRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateCredentialRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.DeleteCredentialRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateRoleRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.DropRoleRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.OperateUserRoleRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.OperatePrivilegeRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateAliasRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.DropAliasRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.AlterAliasRequest{})
})
}
func TestAppendUserInfoForRPC(t *testing.T) { func TestAppendUserInfoForRPC(t *testing.T) {
ctx := GetContext(context.Background(), "root:123456") ctx := GetContext(context.Background(), "root:123456")
ctx = AppendUserInfoForRPC(ctx) ctx = AppendUserInfoForRPC(ctx)

View File

@ -48,12 +48,10 @@ import (
"github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/internal/util/reduce" "github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/internal/util/searchutil/optimizers" "github.com/milvus-io/milvus/internal/util/searchutil/optimizers"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/internal/util/streamrpc" "github.com/milvus-io/milvus/internal/util/streamrpc"
"github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics" "github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb" "github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb" "github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil" "github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
@ -136,8 +134,6 @@ type shardDelegator struct {
// stream delete buffer // stream delete buffer
deleteMut sync.RWMutex deleteMut sync.RWMutex
deleteBuffer deletebuffer.DeleteBuffer[*deletebuffer.Item] deleteBuffer deletebuffer.DeleteBuffer[*deletebuffer.Item]
// dispatcherClient msgdispatcher.Client
factory msgstream.Factory
sf conc.Singleflight[struct{}] sf conc.Singleflight[struct{}]
loader segments.Loader loader segments.Loader
@ -931,7 +927,7 @@ func (sd *shardDelegator) speedupGuranteeTS(
// when 1. streaming service is disable, // when 1. streaming service is disable,
// 2. consistency level is not strong, // 2. consistency level is not strong,
// 3. cannot speed iterator, because current client of milvus doesn't support shard level mvcc. // 3. cannot speed iterator, because current client of milvus doesn't support shard level mvcc.
if !streamingutil.IsStreamingServiceEnabled() || isIterator || cl != commonpb.ConsistencyLevel_Strong || mvccTS != 0 { if isIterator || cl != commonpb.ConsistencyLevel_Strong || mvccTS != 0 {
return guaranteeTS return guaranteeTS
} }
// use the mvcc timestamp of the wal as the guarantee timestamp to make fast strong consistency search. // use the mvcc timestamp of the wal as the guarantee timestamp to make fast strong consistency search.
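
The comment above describes the fast path for strong-consistency reads: when the shard's WAL is served locally, its latest MVCC timestamp can stand in for the allocated guarantee timestamp, so the delegator does not have to wait for a newer time tick. A sketch of that idea only (the return types of GetLatestMVCCTimestampIfLocal are inferred from the mock set up later in this commit, and treating 0 as "no local WAL" is an assumption):

    // Sketch of the fast strong-consistency path; the real logic lives in
    // speedupGuranteeTS above.
    func speedupSketch(ctx context.Context, vchannel string, guaranteeTS uint64) uint64 {
        mvcc, err := streaming.WAL().Local().GetLatestMVCCTimestampIfLocal(ctx, vchannel)
        if err != nil || mvcc == 0 {
            // No local WAL for this channel: keep the allocated guarantee timestamp.
            return guaranteeTS
        }
        // Reading at the WAL's MVCC time is enough for strong consistency and
        // avoids blocking on a fresh time tick.
        return mvcc
    }
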
@ -1145,8 +1141,7 @@ func (sd *shardDelegator) loadPartitionStats(ctx context.Context, partStatsVersi
// NewShardDelegator creates a new ShardDelegator instance with all fields initialized. // NewShardDelegator creates a new ShardDelegator instance with all fields initialized.
func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID UniqueID, channel string, version int64, func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID UniqueID, channel string, version int64,
workerManager cluster.Manager, manager *segments.Manager, loader segments.Loader, workerManager cluster.Manager, manager *segments.Manager, loader segments.Loader, startTs uint64, queryHook optimizers.QueryHook, chunkManager storage.ChunkManager,
factory msgstream.Factory, startTs uint64, queryHook optimizers.QueryHook, chunkManager storage.ChunkManager,
queryView *channelQueryView, queryView *channelQueryView,
) (ShardDelegator, error) { ) (ShardDelegator, error) {
log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID), log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID),
@ -1184,7 +1179,6 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
pkOracle: pkoracle.NewPkOracle(), pkOracle: pkoracle.NewPkOracle(),
latestTsafe: atomic.NewUint64(startTs), latestTsafe: atomic.NewUint64(startTs),
loader: loader, loader: loader,
factory: factory,
queryHook: queryHook, queryHook: queryHook,
chunkManager: chunkManager, chunkManager: chunkManager,
partitionStats: make(map[UniqueID]*storage.PartitionStatsSnapshot), partitionStats: make(map[UniqueID]*storage.PartitionStatsSnapshot),

View File

@ -19,7 +19,6 @@ package delegator
import ( import (
"context" "context"
"fmt" "fmt"
"math/rand"
"runtime" "runtime"
"time" "time"
@ -40,8 +39,6 @@ import (
"github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics" "github.com/milvus-io/milvus/pkg/v2/metrics"
mqcommon "github.com/milvus-io/milvus/pkg/v2/mq/common"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb" "github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb" "github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb" "github.com/milvus-io/milvus/pkg/v2/proto/querypb"
@ -771,120 +768,6 @@ func (sd *shardDelegator) loadStreamDelete(ctx context.Context,
return nil return nil
} }
func (sd *shardDelegator) createStreamFromMsgStream(ctx context.Context, position *msgpb.MsgPosition) (ch <-chan *msgstream.MsgPack, closer func(), err error) {
stream, err := sd.factory.NewTtMsgStream(ctx)
if err != nil {
return nil, nil, err
}
defer stream.Close()
vchannelName := position.ChannelName
pChannelName := funcutil.ToPhysicalChannel(vchannelName)
position.ChannelName = pChannelName
ts, _ := tsoutil.ParseTS(position.Timestamp)
// Random the subname in case we trying to load same delta at the same time
subName := fmt.Sprintf("querynode-delta-loader-%d-%d-%d", paramtable.GetNodeID(), sd.collectionID, rand.Int())
log.Info("from dml check point load delete", zap.Any("position", position), zap.String("vChannel", vchannelName), zap.String("subName", subName), zap.Time("positionTs", ts))
err = stream.AsConsumer(context.TODO(), []string{pChannelName}, subName, mqcommon.SubscriptionPositionUnknown)
if err != nil {
return nil, stream.Close, err
}
err = stream.Seek(context.TODO(), []*msgpb.MsgPosition{position}, false)
if err != nil {
return nil, stream.Close, err
}
dispatcher := msgstream.NewSimpleMsgDispatcher(stream, func(pm msgstream.ConsumeMsg) bool {
if pm.GetType() != commonpb.MsgType_Delete || pm.GetVChannel() != vchannelName {
return false
}
return true
})
return dispatcher.Chan(), dispatcher.Close, nil
}
// Only used in test.
func (sd *shardDelegator) readDeleteFromMsgstream(ctx context.Context, position *msgpb.MsgPosition, safeTs uint64, candidate *pkoracle.BloomFilterSet) (*storage.DeleteData, error) {
log := sd.getLogger(ctx).With(
zap.String("channel", position.ChannelName),
zap.Int64("segmentID", candidate.ID()),
)
pChannelName := funcutil.ToPhysicalChannel(position.ChannelName)
var ch <-chan *msgstream.MsgPack
var closer func()
var err error
ch, closer, err = sd.createStreamFromMsgStream(ctx, position)
if closer != nil {
defer closer()
}
if err != nil {
return nil, err
}
start := time.Now()
result := &storage.DeleteData{}
hasMore := true
for hasMore {
select {
case <-ctx.Done():
log.Debug("read delta msg from seek position done", zap.Error(ctx.Err()))
return nil, ctx.Err()
case msgPack, ok := <-ch:
if !ok {
err = fmt.Errorf("stream channel closed, pChannelName=%v, msgID=%v", pChannelName, position.GetMsgID())
log.Warn("fail to read delta msg",
zap.String("pChannelName", pChannelName),
zap.Binary("msgID", position.GetMsgID()),
zap.Error(err),
)
return nil, err
}
if msgPack == nil {
continue
}
for _, tsMsg := range msgPack.Msgs {
if tsMsg.Type() == commonpb.MsgType_Delete {
dmsg := tsMsg.(*msgstream.DeleteMsg)
if dmsg.CollectionID != sd.collectionID || (dmsg.GetPartitionID() != common.AllPartitionsID && dmsg.GetPartitionID() != candidate.Partition()) {
continue
}
pks := storage.ParseIDs2PrimaryKeys(dmsg.GetPrimaryKeys())
batchSize := paramtable.Get().CommonCfg.BloomFilterApplyBatchSize.GetAsInt()
for idx := 0; idx < len(pks); idx += batchSize {
endIdx := idx + batchSize
if endIdx > len(pks) {
endIdx = len(pks)
}
lc := storage.NewBatchLocationsCache(pks[idx:endIdx])
hits := candidate.BatchPkExist(lc)
for i, hit := range hits {
if hit {
result.Pks = append(result.Pks, pks[idx+i])
result.Tss = append(result.Tss, dmsg.Timestamps[idx+i])
}
}
}
}
}
// reach safe ts
if safeTs <= msgPack.EndPositions[0].GetTimestamp() {
hasMore = false
}
}
}
log.Info("successfully read delete from stream ", zap.Duration("time spent", time.Since(start)))
return result, nil
}
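
With the msgstream consumer gone, replaying deletes from a checkpoint presumably goes through the streaming scanner instead (the same wal.Read/Scanner pair mocked in the querynode service test later in this commit). A small helper sketch built only on the Scanner methods visible there; Scanner is assumed to be the exported interface behind mock_streaming.MockScanner, and how the read option is constructed is not shown here:

    // Block until a streaming scanner finishes or the context ends.
    func waitForScanner(ctx context.Context, s streaming.Scanner) error {
        defer s.Close()
        select {
        case <-ctx.Done():
            return ctx.Err()
        case <-s.Done():
            // Done() closing means the scan stopped; Error() reports why.
            return s.Error()
        }
    }
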
// ReleaseSegments releases segments local or remotely depending on the target node. // ReleaseSegments releases segments local or remotely depending on the target node.
func (sd *shardDelegator) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest, force bool) error { func (sd *shardDelegator) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest, force bool) error {
log := sd.getLogger(ctx) log := sd.getLogger(ctx)

View File

@ -192,11 +192,7 @@ func (s *DelegatorDataSuite) genCollectionWithFunction() {
}}, }},
}, nil, &querypb.LoadMetaInfo{SchemaVersion: tsoutil.ComposeTSByTime(time.Now(), 0)}) }, nil, &querypb.LoadMetaInfo{SchemaVersion: tsoutil.ComposeTSByTime(time.Now(), 0)})
delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, &msgstream.MockMqFactory{ delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.NoError(err) s.NoError(err)
s.delegator = delegator.(*shardDelegator) s.delegator = delegator.(*shardDelegator)
} }
@ -214,11 +210,7 @@ func (s *DelegatorDataSuite) SetupTest() {
s.rootPath = s.Suite.T().Name() s.rootPath = s.Suite.T().Name()
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath) chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath)
s.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(context.Background()) s.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(context.Background())
delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, &msgstream.MockMqFactory{ delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.Require().NoError(err) s.Require().NoError(err)
sd, ok := delegator.(*shardDelegator) sd, ok := delegator.(*shardDelegator)
s.Require().True(ok) s.Require().True(ok)
@ -806,11 +798,7 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
s.workerManager, s.workerManager,
s.manager, s.manager,
s.loader, s.loader,
&msgstream.MockMqFactory{ 10000, nil, nil, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, nil, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.NoError(err) s.NoError(err)
growing0 := segments.NewMockSegment(s.T()) growing0 := segments.NewMockSegment(s.T())
@ -1524,39 +1512,6 @@ func (s *DelegatorDataSuite) TestLevel0Deletions() {
s.Empty(pks) s.Empty(pks)
} }
func (s *DelegatorDataSuite) TestReadDeleteFromMsgstream() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.mq.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
s.mq.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil)
s.mq.EXPECT().Close()
ch := make(chan *msgstream.ConsumeMsgPack, 10)
s.mq.EXPECT().Chan().Return(ch)
oracle := pkoracle.NewBloomFilterSet(1, 1, commonpb.SegmentState_Sealed)
oracle.UpdateBloomFilter([]storage.PrimaryKey{storage.NewInt64PrimaryKey(1), storage.NewInt64PrimaryKey(2)})
baseMsg := &commonpb.MsgBase{MsgType: commonpb.MsgType_Delete}
datas := []*msgstream.MsgPack{
{EndTs: 10, EndPositions: []*msgpb.MsgPosition{{Timestamp: 10}}, Msgs: []msgstream.TsMsg{
&msgstream.DeleteMsg{DeleteRequest: &msgpb.DeleteRequest{Base: baseMsg, CollectionID: s.collectionID, PartitionID: 1, PrimaryKeys: storage.ParseInt64s2IDs(1), Timestamps: []uint64{1}}},
&msgstream.DeleteMsg{DeleteRequest: &msgpb.DeleteRequest{Base: baseMsg, CollectionID: s.collectionID, PartitionID: -1, PrimaryKeys: storage.ParseInt64s2IDs(2), Timestamps: []uint64{5}}},
// invalid msg because partition wrong
&msgstream.DeleteMsg{DeleteRequest: &msgpb.DeleteRequest{Base: baseMsg, CollectionID: s.collectionID, PartitionID: 2, PrimaryKeys: storage.ParseInt64s2IDs(1), Timestamps: []uint64{10}}},
}},
}
for _, data := range datas {
ch <- msgstream.BuildConsumeMsgPack(data)
}
result, err := s.delegator.readDeleteFromMsgstream(ctx, &msgpb.MsgPosition{Timestamp: 0}, 10, oracle)
s.NoError(err)
s.Equal(2, len(result.Pks))
}
func (s *DelegatorDataSuite) TestDelegatorData_ExcludeSegments() { func (s *DelegatorDataSuite) TestDelegatorData_ExcludeSegments() {
s.delegator.AddExcludedSegments(map[int64]uint64{ s.delegator.AddExcludedSegments(map[int64]uint64{
1: 3, 1: 3,

View File

@ -19,6 +19,7 @@ package delegator
import ( import (
"context" "context"
"io" "io"
"os"
"testing" "testing"
"time" "time"
@ -33,6 +34,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/querynodev2/cluster" "github.com/milvus-io/milvus/internal/querynodev2/cluster"
"github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/storage"
@ -51,6 +53,12 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/typeutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
) )
func TestMain(m *testing.M) {
streaming.SetupNoopWALForTest()
os.Exit(m.Run())
}
type DelegatorSuite struct { type DelegatorSuite struct {
suite.Suite suite.Suite
@ -164,11 +172,7 @@ func (s *DelegatorSuite) SetupTest() {
var err error var err error
// s.delegator, err = NewShardDelegator(s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.tsafeManager, s.loader) // s.delegator, err = NewShardDelegator(s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.tsafeManager, s.loader)
s.delegator, err = NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, &msgstream.MockMqFactory{ s.delegator, err = NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.Require().NoError(err) s.Require().NoError(err)
} }
@ -203,11 +207,7 @@ func (s *DelegatorSuite) TestCreateDelegatorWithFunction() {
}}, }},
}, nil, &querypb.LoadMetaInfo{SchemaVersion: tsoutil.ComposeTSByTime(time.Now(), 0)}) }, nil, &querypb.LoadMetaInfo{SchemaVersion: tsoutil.ComposeTSByTime(time.Now(), 0)})
_, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, manager, s.loader, &msgstream.MockMqFactory{ _, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.Error(err) s.Error(err)
}) })
@ -246,11 +246,7 @@ func (s *DelegatorSuite) TestCreateDelegatorWithFunction() {
}}, }},
}, nil, &querypb.LoadMetaInfo{SchemaVersion: tsoutil.ComposeTSByTime(time.Now(), 0)}) }, nil, &querypb.LoadMetaInfo{SchemaVersion: tsoutil.ComposeTSByTime(time.Now(), 0)})
_, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, manager, s.loader, &msgstream.MockMqFactory{ _, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.NoError(err) s.NoError(err)
}) })
} }
@ -1410,11 +1406,7 @@ func (s *DelegatorSuite) TestUpdateSchema() {
func (s *DelegatorSuite) ResetDelegator() { func (s *DelegatorSuite) ResetDelegator() {
var err error var err error
s.delegator.Close() s.delegator.Close()
s.delegator, err = NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, &msgstream.MockMqFactory{ s.delegator, err = NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.Require().NoError(err) s.Require().NoError(err)
} }

View File

@ -151,11 +151,7 @@ func (s *StreamingForwardSuite) SetupTest() {
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath) chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath)
s.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(context.Background()) s.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(context.Background())
delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, &msgstream.MockMqFactory{ delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.Require().NoError(err) s.Require().NoError(err)
sd, ok := delegator.(*shardDelegator) sd, ok := delegator.(*shardDelegator)
@ -394,11 +390,7 @@ func (s *GrowingMergeL0Suite) SetupTest() {
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath) chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath)
s.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(context.Background()) s.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(context.Background())
delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, &msgstream.MockMqFactory{ delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.Require().NoError(err) s.Require().NoError(err)
sd, ok := delegator.(*shardDelegator) sd, ok := delegator.(*shardDelegator)

View File

@ -65,7 +65,6 @@ import (
"github.com/milvus-io/milvus/internal/util/searchutil/scheduler" "github.com/milvus-io/milvus/internal/util/searchutil/scheduler"
"github.com/milvus-io/milvus/internal/util/segcore" "github.com/milvus-io/milvus/internal/util/segcore"
"github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/internal/util/streamingutil/util" "github.com/milvus-io/milvus/internal/util/streamingutil/util"
"github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/config" "github.com/milvus-io/milvus/pkg/v2/config"
@ -528,11 +527,7 @@ func (node *QueryNode) Init() error {
node.manager = segments.NewManager() node.manager = segments.NewManager()
node.loader = segments.NewLoader(node.ctx, node.manager, node.chunkManager) node.loader = segments.NewLoader(node.ctx, node.manager, node.chunkManager)
node.manager.SetLoader(node.loader) node.manager.SetLoader(node.loader)
if streamingutil.IsStreamingServiceEnabled() { node.dispClient = msgdispatcher.NewClientWithIncludeSkipWhenSplit(streaming.NewDelegatorMsgstreamFactory(), typeutil.QueryNodeRole, node.GetNodeID())
node.dispClient = msgdispatcher.NewClientWithIncludeSkipWhenSplit(streaming.NewDelegatorMsgstreamFactory(), typeutil.QueryNodeRole, node.GetNodeID())
} else {
node.dispClient = msgdispatcher.NewClient(node.factory, typeutil.QueryNodeRole, node.GetNodeID())
}
// init pipeline manager // init pipeline manager
node.pipelineManager = pipeline.NewManager(node.manager, node.dispClient, node.delegators) node.pipelineManager = pipeline.NewManager(node.manager, node.dispClient, node.delegators)

View File

@ -267,7 +267,6 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
node.clusterManager, node.clusterManager,
node.manager, node.manager,
node.loader, node.loader,
node.factory,
channel.GetSeekPosition().GetTimestamp(), channel.GetSeekPosition().GetTimestamp(),
node.queryHook, node.queryHook,
node.chunkManager, node.chunkManager,

View File

@ -49,12 +49,12 @@ import (
"github.com/milvus-io/milvus/internal/util/streamrpc" "github.com/milvus-io/milvus/internal/util/streamrpc"
"github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb" "github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/indexpb" "github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb" "github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb" "github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types" "github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
"github.com/milvus-io/milvus/pkg/v2/util/conc" "github.com/milvus-io/milvus/pkg/v2/util/conc"
"github.com/milvus-io/milvus/pkg/v2/util/etcd" "github.com/milvus-io/milvus/pkg/v2/util/etcd"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil" "github.com/milvus-io/milvus/pkg/v2/util/funcutil"
@ -68,7 +68,6 @@ import (
type ServiceSuite struct { type ServiceSuite struct {
suite.Suite suite.Suite
// Data // Data
msgChan chan *msgstream.ConsumeMsgPack
collectionID int64 collectionID int64
collectionName string collectionName string
schema *schemapb.CollectionSchema schema *schemapb.CollectionSchema
@ -92,8 +91,7 @@ type ServiceSuite struct {
chunkManagerFactory *storage.ChunkManagerFactory chunkManagerFactory *storage.ChunkManagerFactory
// Mock // Mock
factory *dependency.MockFactory factory *dependency.MockFactory
msgStream *msgstream.MockMsgStream
} }
func (suite *ServiceSuite) SetupSuite() { func (suite *ServiceSuite) SetupSuite() {
@ -129,7 +127,6 @@ func (suite *ServiceSuite) SetupTest() {
ctx := context.Background() ctx := context.Background()
// init mock // init mock
suite.factory = dependency.NewMockFactory(suite.T()) suite.factory = dependency.NewMockFactory(suite.T())
suite.msgStream = msgstream.NewMockMsgStream(suite.T())
// TODO:: cpp chunk manager not support local chunk manager // TODO:: cpp chunk manager not support local chunk manager
paramtable.Get().Save(paramtable.Get().LocalStorageCfg.Path.Key, suite.T().TempDir()) paramtable.Get().Save(paramtable.Get().LocalStorageCfg.Path.Key, suite.T().TempDir())
// suite.chunkManagerFactory = storage.NewChunkManagerFactory("local", storage.RootPath("/tmp/milvus-test")) // suite.chunkManagerFactory = storage.NewChunkManagerFactory("local", storage.RootPath("/tmp/milvus-test"))
@ -315,13 +312,6 @@ func (suite *ServiceSuite) TestWatchDmChannelsInt64() {
IndexInfoList: mock_segcore.GenTestIndexInfoList(suite.collectionID, schema), IndexInfoList: mock_segcore.GenTestIndexInfoList(suite.collectionID, schema),
} }
// mocks
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil).Maybe()
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil).Maybe()
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()
suite.msgStream.EXPECT().Chan().Return(suite.msgChan).Maybe()
suite.msgStream.EXPECT().Close().Maybe()
// watchDmChannels // watchDmChannels
status, err := suite.node.WatchDmChannels(ctx, req) status, err := suite.node.WatchDmChannels(ctx, req)
suite.NoError(err) suite.NoError(err)
@ -367,13 +357,6 @@ func (suite *ServiceSuite) TestWatchDmChannelsVarchar() {
IndexInfoList: mock_segcore.GenTestIndexInfoList(suite.collectionID, schema), IndexInfoList: mock_segcore.GenTestIndexInfoList(suite.collectionID, schema),
} }
// mocks
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil).Maybe()
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil).Maybe()
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()
suite.msgStream.EXPECT().Chan().Return(suite.msgChan).Maybe()
suite.msgStream.EXPECT().Close().Maybe()
// watchDmChannels // watchDmChannels
status, err := suite.node.WatchDmChannels(ctx, req) status, err := suite.node.WatchDmChannels(ctx, req)
suite.NoError(err) suite.NoError(err)
@ -2437,6 +2420,13 @@ func TestQueryNodeService(t *testing.T) {
local.EXPECT().GetLatestMVCCTimestampIfLocal(mock.Anything, mock.Anything).Return(0, nil).Maybe() local.EXPECT().GetLatestMVCCTimestampIfLocal(mock.Anything, mock.Anything).Return(0, nil).Maybe()
local.EXPECT().GetMetricsIfLocal(mock.Anything).Return(&types.StreamingNodeMetrics{}, nil).Maybe() local.EXPECT().GetMetricsIfLocal(mock.Anything).Return(&types.StreamingNodeMetrics{}, nil).Maybe()
wal.EXPECT().Local().Return(local).Maybe() wal.EXPECT().Local().Return(local).Maybe()
wal.EXPECT().WALName().Return(rmq.WALName).Maybe()
scanner := mock_streaming.NewMockScanner(t)
scanner.EXPECT().Done().Return(make(chan struct{})).Maybe()
scanner.EXPECT().Error().Return(nil).Maybe()
scanner.EXPECT().Close().Return().Maybe()
wal.EXPECT().Read(mock.Anything, mock.Anything).Return(scanner).Maybe()
streaming.SetWALForTest(wal) streaming.SetWALForTest(wal)
defer streaming.RecoverWALForTest() defer streaming.RecoverWALForTest()

View File

@ -1,98 +0,0 @@
package rootcoord
import (
"context"
"github.com/milvus-io/milvus/internal/util/streamingutil/util"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message/adaptor"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
var _ task = (*broadcastTask)(nil)
// BroadcastTask is used to implement the broadcast operation based on the msgstream
// by using the streaming service interface.
// msgstream will be deprecated since 2.6.0 with streaming service, so those code will be removed in the future version.
type broadcastTask struct {
baseTask
msgs []message.MutableMessage // The message wait for broadcast
walName string
resultFuture *syncutil.Future[types.AppendResponses]
}
func (b *broadcastTask) Execute(ctx context.Context) error {
result := types.NewAppendResponseN(len(b.msgs))
defer func() {
b.resultFuture.Set(result)
}()
for idx, msg := range b.msgs {
tsMsg, err := adaptor.NewMsgPackFromMutableMessageV1(msg)
tsMsg.SetTs(b.ts) // overwrite the ts.
if err != nil {
result.FillResponseAtIdx(types.AppendResponse{Error: err}, idx)
return err
}
pchannel := funcutil.ToPhysicalChannel(msg.VChannel())
msgID, err := b.core.chanTimeTick.broadcastMarkDmlChannels([]string{pchannel}, &msgstream.MsgPack{
BeginTs: b.ts,
EndTs: b.ts,
Msgs: []msgstream.TsMsg{tsMsg},
})
if err != nil {
result.FillResponseAtIdx(types.AppendResponse{Error: err}, idx)
continue
}
result.FillResponseAtIdx(types.AppendResponse{
AppendResult: &types.AppendResult{
MessageID: adaptor.MustGetMessageIDFromMQWrapperIDBytes(b.walName, msgID[pchannel]),
TimeTick: b.ts,
},
}, idx)
}
return result.UnwrapFirstError()
}
func newMsgStreamAppendOperator(c *Core) *msgstreamAppendOperator {
return &msgstreamAppendOperator{
core: c,
walName: util.MustSelectWALName(),
}
}
// msgstreamAppendOperator the code of streamingcoord to make broadcast available on the legacy msgstream.
// Because msgstream is bound to the rootcoord task, so we transfer each broadcast operation into a ddl task.
// to make sure the timetick rule.
// The Msgstream will be deprecated since 2.6.0, so we make a single module to hold it.
type msgstreamAppendOperator struct {
core *Core
walName string
}
// AppendMessages implements the AppendOperator interface for broadcaster service at streaming service.
func (m *msgstreamAppendOperator) AppendMessages(ctx context.Context, msgs ...message.MutableMessage) types.AppendResponses {
t := &broadcastTask{
baseTask: newBaseTask(ctx, m.core),
msgs: msgs,
walName: m.walName,
resultFuture: syncutil.NewFuture[types.AppendResponses](),
}
if err := m.core.scheduler.AddTask(t); err != nil {
resp := types.NewAppendResponseN(len(msgs))
resp.FillAllError(err)
return resp
}
result, err := t.resultFuture.GetWithContext(ctx)
if err != nil {
resp := types.NewAppendResponseN(len(msgs))
resp.FillAllError(err)
return resp
}
return result
}

View File

@ -44,7 +44,6 @@ import (
kvmetastore "github.com/milvus-io/milvus/internal/metastore/kv/rootcoord" kvmetastore "github.com/milvus-io/milvus/internal/metastore/kv/rootcoord"
"github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/metastore/model"
streamingcoord "github.com/milvus-io/milvus/internal/streamingcoord/server" streamingcoord "github.com/milvus-io/milvus/internal/streamingcoord/server"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
tso2 "github.com/milvus-io/milvus/internal/tso" tso2 "github.com/milvus-io/milvus/internal/tso"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/dependency"
@ -694,11 +693,6 @@ func (c *Core) startInternal() error {
c.UpdateStateCode(commonpb.StateCode_Healthy) c.UpdateStateCode(commonpb.StateCode_Healthy)
sessionutil.SaveServerInfo(typeutil.MixCoordRole, c.session.GetServerID()) sessionutil.SaveServerInfo(typeutil.MixCoordRole, c.session.GetServerID())
log.Info("rootcoord startup successfully") log.Info("rootcoord startup successfully")
// regster the core as a appendoperator for broadcast service.
// TODO: should be removed at 2.6.0.
// Add the wal accesser to the broadcaster registry for making broadcast operation.
registry.Register(registry.AppendOperatorTypeMsgstream, newMsgStreamAppendOperator(c))
return nil return nil
} }

View File

@ -12,7 +12,6 @@ import (
"github.com/milvus-io/milvus/internal/streamingcoord/client/assignment" "github.com/milvus-io/milvus/internal/streamingcoord/client/assignment"
"github.com/milvus-io/milvus/internal/streamingcoord/client/broadcast" "github.com/milvus-io/milvus/internal/streamingcoord/client/broadcast"
"github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/balancer/picker" "github.com/milvus-io/milvus/internal/util/streamingutil/service/balancer/picker"
streamingserviceinterceptor "github.com/milvus-io/milvus/internal/util/streamingutil/service/interceptor" streamingserviceinterceptor "github.com/milvus-io/milvus/internal/util/streamingutil/service/interceptor"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/lazygrpc" "github.com/milvus-io/milvus/internal/util/streamingutil/service/lazygrpc"
@ -76,11 +75,8 @@ func NewClient(etcdCli *clientv3.Client) Client {
dialOptions..., dialOptions...,
) )
}) })
var assignmentServiceImpl *assignment.AssignmentServiceImpl assignmentService := lazygrpc.WithServiceCreator(conn, streamingpb.NewStreamingCoordAssignmentServiceClient)
if streamingutil.IsStreamingServiceEnabled() { assignmentServiceImpl := assignment.NewAssignmentService(assignmentService)
assignmentService := lazygrpc.WithServiceCreator(conn, streamingpb.NewStreamingCoordAssignmentServiceClient)
assignmentServiceImpl = assignment.NewAssignmentService(assignmentService)
}
broadcastService := lazygrpc.WithServiceCreator(conn, streamingpb.NewStreamingCoordBroadcastServiceClient) broadcastService := lazygrpc.WithServiceCreator(conn, streamingpb.NewStreamingCoordBroadcastServiceClient)
return &clientImpl{ return &clientImpl{
conn: conn, conn: conn,

View File

@ -1,14 +0,0 @@
package broadcaster
import (
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/util/streamingutil"
)
// NewAppendOperator creates an append operator to handle the incoming messages for broadcaster.
func NewAppendOperator() AppendOperator {
if streamingutil.IsStreamingServiceEnabled() {
return streaming.WAL()
}
return nil
}

View File

@ -3,7 +3,6 @@ package broadcaster
import ( import (
"context" "context"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message" "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types" "github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
) )
@ -19,4 +18,9 @@ type Broadcaster interface {
Close() Close()
} }
type AppendOperator = registry.AppendOperator // AppendOperator is used to append messages; there are only two implementations of this interface:
// 1. streaming.WAL()
// 2. old msgstream interface [deprecated]
type AppendOperator interface {
AppendMessages(ctx context.Context, msgs ...message.MutableMessage) types.AppendResponses
}
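
The interface above decouples the broadcaster from the concrete WAL accessor. A hedged sketch of the same single-method-interface pattern, with hypothetical local stand-ins for MutableMessage and AppendResponses (not the Milvus types), showing how any conforming type can back the broadcaster:

package main

import (
    "context"
    "fmt"
)

type MutableMessage struct{ Payload string }

type AppendResponses struct{ Appended int }

// AppendOperator mirrors the shape of the interface in the diff above.
type AppendOperator interface {
    AppendMessages(ctx context.Context, msgs ...MutableMessage) AppendResponses
}

// fakeWAL is a hypothetical in-memory implementation used only for this sketch.
type fakeWAL struct{ log []MutableMessage }

func (w *fakeWAL) AppendMessages(ctx context.Context, msgs ...MutableMessage) AppendResponses {
    w.log = append(w.log, msgs...)
    return AppendResponses{Appended: len(msgs)}
}

// broadcast only depends on the interface, not on the concrete WAL accessor.
func broadcast(ctx context.Context, op AppendOperator, msgs ...MutableMessage) {
    resp := op.AppendMessages(ctx, msgs...)
    fmt.Println("appended:", resp.Appended)
}

func main() {
    broadcast(context.Background(), &fakeWAL{}, MutableMessage{"m1"}, MutableMessage{"m2"})
}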

View File

@ -22,7 +22,6 @@ import (
func RecoverBroadcaster( func RecoverBroadcaster(
ctx context.Context, ctx context.Context,
appendOperator *syncutil.Future[AppendOperator],
) (Broadcaster, error) { ) (Broadcaster, error) {
tasks, err := resource.Resource().StreamingCatalog().ListBroadcastTask(ctx) tasks, err := resource.Resource().StreamingCatalog().ListBroadcastTask(ctx)
if err != nil { if err != nil {
@ -38,7 +37,6 @@ func RecoverBroadcaster(
backoffChan: make(chan *pendingBroadcastTask), backoffChan: make(chan *pendingBroadcastTask),
pendingChan: make(chan *pendingBroadcastTask), pendingChan: make(chan *pendingBroadcastTask),
workerChan: make(chan *pendingBroadcastTask), workerChan: make(chan *pendingBroadcastTask),
appendOperator: appendOperator,
} }
go b.execute() go b.execute()
return b, nil return b, nil
@ -54,7 +52,6 @@ type broadcasterImpl struct {
pendingChan chan *pendingBroadcastTask pendingChan chan *pendingBroadcastTask
backoffChan chan *pendingBroadcastTask backoffChan chan *pendingBroadcastTask
workerChan chan *pendingBroadcastTask workerChan chan *pendingBroadcastTask
appendOperator *syncutil.Future[AppendOperator] // TODO: we can remove those lazy future in 2.6.0, by remove the msgstream broadcaster.
} }
// Broadcast broadcasts the message to all channels. // Broadcast broadcasts the message to all channels.
@ -140,14 +137,6 @@ func (b *broadcasterImpl) execute() {
b.Logger().Info("broadcaster execute exit") b.Logger().Info("broadcaster execute exit")
}() }()
// Wait for appendOperator ready
appendOperator, err := b.appendOperator.GetWithContext(b.backgroundTaskNotifier.Context())
if err != nil {
b.Logger().Info("broadcaster is closed before appendOperator ready")
return
}
b.Logger().Info("broadcaster appendOperator ready, begin to start workers and dispatch")
// Start n workers to handle the broadcast task. // Start n workers to handle the broadcast task.
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
for i := 0; i < workers; i++ { for i := 0; i < workers; i++ {
@ -156,7 +145,7 @@ func (b *broadcasterImpl) execute() {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
b.worker(i, appendOperator) b.worker(i)
}() }()
} }
defer wg.Wait() defer wg.Wait()
@ -205,7 +194,7 @@ func (b *broadcasterImpl) dispatch() {
} }
} }
func (b *broadcasterImpl) worker(no int, appendOperator AppendOperator) { func (b *broadcasterImpl) worker(no int) {
logger := b.Logger().With(zap.Int("workerNo", no)) logger := b.Logger().With(zap.Int("workerNo", no))
defer func() { defer func() {
logger.Info("broadcaster worker exit") logger.Info("broadcaster worker exit")
@ -216,7 +205,7 @@ func (b *broadcasterImpl) worker(no int, appendOperator AppendOperator) {
case <-b.backgroundTaskNotifier.Context().Done(): case <-b.backgroundTaskNotifier.Context().Done():
return return
case task := <-b.workerChan: case task := <-b.workerChan:
if err := task.Execute(b.backgroundTaskNotifier.Context(), appendOperator); err != nil { if err := task.Execute(b.backgroundTaskNotifier.Context()); err != nil {
// If the task is not done, repush it into pendings and retry infinitely. // If the task is not done, repush it into pendings and retry infinitely.
select { select {
case <-b.backgroundTaskNotifier.Context().Done(): case <-b.backgroundTaskNotifier.Context().Done():
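
With the lazy appendOperator wait removed, execute() simply fans tasks out to N workers over a channel and re-queues failed tasks for retry. A rough sketch of that worker-pool-with-retry shape using only plain channels; all names here are hypothetical stand-ins for the Milvus notifier/backoff types:

package main

import (
    "context"
    "fmt"
    "sync"
)

type task struct{ id, attempts int }

// worker mirrors the loop above: take a task, and on failure hand it back
// on the retry channel instead of blocking the pool.
func worker(ctx context.Context, no int, work <-chan task, retry chan<- task, wg *sync.WaitGroup) {
    defer wg.Done()
    for {
        select {
        case <-ctx.Done():
            return
        case t := <-work:
            if t.attempts == 0 { // pretend the first attempt fails
                t.attempts++
                select {
                case retry <- t:
                case <-ctx.Done():
                }
                continue
            }
            fmt.Printf("worker %d finished task %d\n", no, t.id)
        }
    }
}

func main() {
    ctx, cancel := context.WithCancel(context.Background())
    work := make(chan task)
    retry := make(chan task)

    const workers = 2
    var wg sync.WaitGroup
    for i := 0; i < workers; i++ {
        wg.Add(1)
        go worker(ctx, i, work, retry, &wg)
    }

    work <- task{id: 1} // first attempt fails inside a worker
    work <- <-retry     // the dispatcher re-pushes the backed-off task
    cancel()            // shut the pool down
    wg.Wait()
}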

View File

@ -12,8 +12,9 @@ import (
"go.uber.org/atomic" "go.uber.org/atomic"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/mocks/distributed/mock_streaming"
"github.com/milvus-io/milvus/internal/mocks/mock_metastore" "github.com/milvus-io/milvus/internal/mocks/mock_metastore"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_broadcaster"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry" "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource" "github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
internaltypes "github.com/milvus-io/milvus/internal/types" internaltypes "github.com/milvus-io/milvus/internal/types"
@ -76,8 +77,8 @@ func TestBroadcaster(t *testing.T) {
resource.InitForTest(resource.OptStreamingCatalog(meta), resource.OptMixCoordClient(f)) resource.InitForTest(resource.OptStreamingCatalog(meta), resource.OptMixCoordClient(f))
fbc := syncutil.NewFuture[Broadcaster]() fbc := syncutil.NewFuture[Broadcaster]()
operator, appended := createOpeartor(t, fbc) appended := createOpeartor(t, fbc)
bc, err := RecoverBroadcaster(context.Background(), operator) bc, err := RecoverBroadcaster(context.Background())
fbc.Set(bc) fbc.Set(bc)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, bc) assert.NotNil(t, bc)
@ -135,10 +136,10 @@ func ack(broadcaster Broadcaster, broadcastID uint64, vchannel string) {
} }
} }
func createOpeartor(t *testing.T, broadcaster *syncutil.Future[Broadcaster]) (*syncutil.Future[AppendOperator], *atomic.Int64) { func createOpeartor(t *testing.T, broadcaster *syncutil.Future[Broadcaster]) *atomic.Int64 {
id := atomic.NewInt64(1) id := atomic.NewInt64(1)
appended := atomic.NewInt64(0) appended := atomic.NewInt64(0)
operator := mock_broadcaster.NewMockAppendOperator(t) operator := mock_streaming.NewMockWALAccesser(t)
f := func(ctx context.Context, msgs ...message.MutableMessage) types.AppendResponses { f := func(ctx context.Context, msgs ...message.MutableMessage) types.AppendResponses {
resps := types.AppendResponses{ resps := types.AppendResponses{
Responses: make([]types.AppendResponse, len(msgs)), Responses: make([]types.AppendResponse, len(msgs)),
@ -174,9 +175,8 @@ func createOpeartor(t *testing.T, broadcaster *syncutil.Future[Broadcaster]) (*s
operator.EXPECT().AppendMessages(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(f) operator.EXPECT().AppendMessages(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(f)
operator.EXPECT().AppendMessages(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(f) operator.EXPECT().AppendMessages(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(f)
fOperator := syncutil.NewFuture[AppendOperator]() streaming.SetWALForTest(operator)
fOperator.Set(operator) return appended
return fOperator, appended
} }
func createNewBroadcastMsg(vchannels []string, rks ...message.ResourceKey) message.BroadcastMutableMessage { func createNewBroadcastMsg(vchannels []string, rks ...message.ResourceKey) message.BroadcastMutableMessage {
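
The test now swaps the process-wide WAL accessor for a mock via streaming.SetWALForTest instead of wiring a future-based operator. A generic sketch of that "replace a package-level dependency with a test double" pattern; the names below are illustrative stand-ins, not the mockery-generated Milvus mocks:

package main

import (
    "context"
    "fmt"
)

// Accesser is a stand-in for the WAL accessor; only the method used above is sketched.
type Accesser interface {
    AppendMessages(ctx context.Context, msgs ...string) int
}

type realWAL struct{}

func (realWAL) AppendMessages(ctx context.Context, msgs ...string) int { return len(msgs) }

// countingWAL is a hand-rolled test double standing in for MockWALAccesser.
type countingWAL struct{ appended int }

func (c *countingWAL) AppendMessages(ctx context.Context, msgs ...string) int {
    c.appended += len(msgs)
    return len(msgs)
}

var current Accesser = realWAL{}

// WAL mirrors streaming.WAL(): production callers always go through this accessor.
func WAL() Accesser { return current }

// SetWALForTest mirrors streaming.SetWALForTest(): tests swap in a double.
func SetWALForTest(a Accesser) { current = a }

func main() {
    double := &countingWAL{}
    SetWALForTest(double)
    WAL().AppendMessages(context.Background(), "m1", "m2")
    fmt.Println("appended via test double:", double.appended)
}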

View File

@ -1,42 +0,0 @@
package registry
import (
"context"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
type AppendOperatorType int
const (
AppendOperatorTypeMsgstream AppendOperatorType = iota + 1
AppendOperatorTypeStreaming
)
var localRegistry = make(map[AppendOperatorType]*syncutil.Future[AppendOperator])
// AppendOperator is used to append messages, there's only two implement of this interface:
// 1. streaming.WAL()
// 2. old msgstream interface
type AppendOperator interface {
AppendMessages(ctx context.Context, msgs ...message.MutableMessage) types.AppendResponses
}
func init() {
localRegistry[AppendOperatorTypeMsgstream] = syncutil.NewFuture[AppendOperator]()
localRegistry[AppendOperatorTypeStreaming] = syncutil.NewFuture[AppendOperator]()
}
func Register(typ AppendOperatorType, op AppendOperator) {
localRegistry[typ].Set(op)
}
func GetAppendOperator() *syncutil.Future[AppendOperator] {
if streamingutil.IsStreamingServiceEnabled() {
return localRegistry[AppendOperatorTypeStreaming]
}
return localRegistry[AppendOperatorTypeMsgstream]
}
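
The removed registry above existed so that lookups could block until either the msgstream or the streaming append operator was registered; that is why RecoverBroadcaster had to wait on a syncutil.Future before starting its workers. A minimal stand-in for that blocking-future pattern (sketch only, not the real pkg/v2/util/syncutil implementation), to show what callers no longer need now that they go straight to streaming.WAL():

package main

import (
    "context"
    "fmt"
)

type Future[T any] struct {
    done chan struct{}
    val  T
}

func NewFuture[T any]() *Future[T] { return &Future[T]{done: make(chan struct{})} }

// Set publishes the value exactly once and unblocks all waiters.
func (f *Future[T]) Set(v T) {
    f.val = v
    close(f.done)
}

// GetWithContext blocks until Set is called or the context is cancelled.
func (f *Future[T]) GetWithContext(ctx context.Context) (T, error) {
    select {
    case <-ctx.Done():
        var zero T
        return zero, ctx.Err()
    case <-f.done:
        return f.val, nil
    }
}

func main() {
    f := NewFuture[string]()
    go f.Set("wal ready")
    v, err := f.GetWithContext(context.Background())
    fmt.Println(v, err)
}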

View File

@ -3,12 +3,7 @@
package registry package registry
import "github.com/milvus-io/milvus/pkg/v2/util/syncutil"
func ResetRegistration() { func ResetRegistration() {
localRegistry = make(map[AppendOperatorType]*syncutil.Future[AppendOperator])
localRegistry[AppendOperatorTypeMsgstream] = syncutil.NewFuture[AppendOperator]()
localRegistry[AppendOperatorTypeStreaming] = syncutil.NewFuture[AppendOperator]()
resetMessageAckCallbacks() resetMessageAckCallbacks()
resetMessageCheckCallbacks() resetMessageCheckCallbacks()
} }

View File

@ -7,6 +7,7 @@ import (
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"go.uber.org/zap" "go.uber.org/zap"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message" "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types" "github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil" "github.com/milvus-io/milvus/pkg/v2/util/syncutil"
@ -49,7 +50,7 @@ type pendingBroadcastTask struct {
// Execute re-executes the task; it returns nil if the task is done, otherwise the task is not yet done. // Execute re-executes the task; it returns nil if the task is done, otherwise the task is not yet done.
// Execute can be called repeatedly until the task is done. // Execute can be called repeatedly until the task is done.
// Same semantics as the `Poll` operation in eventloop. // Same semantics as the `Poll` operation in eventloop.
func (b *pendingBroadcastTask) Execute(ctx context.Context, operator AppendOperator) error { func (b *pendingBroadcastTask) Execute(ctx context.Context) error {
if err := b.broadcastTask.InitializeRecovery(ctx); err != nil { if err := b.broadcastTask.InitializeRecovery(ctx); err != nil {
b.Logger().Warn("broadcast task initialize recovery failed", zap.Error(err)) b.Logger().Warn("broadcast task initialize recovery failed", zap.Error(err))
b.UpdateInstantWithNextBackOff() b.UpdateInstantWithNextBackOff()
@ -58,7 +59,7 @@ func (b *pendingBroadcastTask) Execute(ctx context.Context, operator AppendOpera
if len(b.pendingMessages) > 0 { if len(b.pendingMessages) > 0 {
b.Logger().Debug("broadcast task is polling to make sent...", zap.Int("pendingMessages", len(b.pendingMessages))) b.Logger().Debug("broadcast task is polling to make sent...", zap.Int("pendingMessages", len(b.pendingMessages)))
resps := operator.AppendMessages(ctx, b.pendingMessages...) resps := streaming.WAL().AppendMessages(ctx, b.pendingMessages...)
newPendings := make([]message.MutableMessage, 0) newPendings := make([]message.MutableMessage, 0)
for idx, resp := range resps.Responses { for idx, resp := range resps.Responses {
if resp.Error != nil { if resp.Error != nil {
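
Execute above is a poll-style pass: append everything still pending against streaming.WAL(), then keep only the failed messages as the next pending set and back off. A hedged sketch of that single pass, with local stand-in types instead of the Milvus message/types packages:

package main

import (
    "context"
    "errors"
    "fmt"
)

type appendResult struct{ Err error }

// appendMessages pretends the first message fails and the rest succeed.
func appendMessages(ctx context.Context, msgs []string) []appendResult {
    out := make([]appendResult, len(msgs))
    for i := range msgs {
        if i == 0 {
            out[i].Err = errors.New("transient append failure")
        }
    }
    return out
}

// poll runs one Execute-style pass and returns the messages still pending.
func poll(ctx context.Context, pending []string) []string {
    resps := appendMessages(ctx, pending)
    newPending := make([]string, 0, len(pending))
    for i, r := range resps {
        if r.Err != nil {
            newPending = append(newPending, pending[i])
        }
    }
    return newPending
}

func main() {
    pending := []string{"m0", "m1", "m2"}
    pending = poll(context.Background(), pending)
    fmt.Println("still pending:", pending) // [m0]
}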

View File

@ -9,7 +9,6 @@ import (
"github.com/milvus-io/milvus/internal/streamingnode/client/manager" "github.com/milvus-io/milvus/internal/streamingnode/client/manager"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/idalloc" "github.com/milvus-io/milvus/internal/util/idalloc"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil" "github.com/milvus-io/milvus/pkg/v2/util/syncutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
@ -55,10 +54,8 @@ func Init(opts ...optResourceInit) {
assertNotNil(newR.MixCoordClient()) assertNotNil(newR.MixCoordClient())
assertNotNil(newR.ETCD()) assertNotNil(newR.ETCD())
assertNotNil(newR.StreamingCatalog()) assertNotNil(newR.StreamingCatalog())
if streamingutil.IsStreamingServiceEnabled() { newR.streamingNodeManagerClient = manager.NewManagerClient(newR.etcdClient)
newR.streamingNodeManagerClient = manager.NewManagerClient(newR.etcdClient) assertNotNil(newR.StreamingNodeManagerClient())
assertNotNil(newR.StreamingNodeManagerClient())
}
r = newR r = newR
} }

View File

@ -10,11 +10,9 @@ import (
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer" "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
_ "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/policy" // register the balancer policy _ "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/policy" // register the balancer policy
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster" "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource" "github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
"github.com/milvus-io/milvus/internal/streamingcoord/server/service" "github.com/milvus-io/milvus/internal/streamingcoord/server/service"
"github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/internal/util/streamingutil/util" "github.com/milvus-io/milvus/internal/util/streamingutil/util"
"github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb" "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
@ -53,27 +51,25 @@ func (s *Server) Start(ctx context.Context) (err error) {
// initBasicComponent initialize all underlying dependency for streamingcoord. // initBasicComponent initialize all underlying dependency for streamingcoord.
func (s *Server) initBasicComponent(ctx context.Context) (err error) { func (s *Server) initBasicComponent(ctx context.Context) (err error) {
futures := make([]*conc.Future[struct{}], 0) futures := make([]*conc.Future[struct{}], 0)
if streamingutil.IsStreamingServiceEnabled() { futures = append(futures, conc.Go(func() (struct{}, error) {
futures = append(futures, conc.Go(func() (struct{}, error) { s.logger.Info("start recovery balancer...")
s.logger.Info("start recovery balancer...") // Read new incoming topics from configuration, and register it into balancer.
// Read new incoming topics from configuration, and register it into balancer. newIncomingTopics := util.GetAllTopicsFromConfiguration()
newIncomingTopics := util.GetAllTopicsFromConfiguration() balancer, err := balancer.RecoverBalancer(ctx, newIncomingTopics.Collect()...)
balancer, err := balancer.RecoverBalancer(ctx, newIncomingTopics.Collect()...) if err != nil {
if err != nil { s.logger.Warn("recover balancer failed", zap.Error(err))
s.logger.Warn("recover balancer failed", zap.Error(err)) return struct{}{}, err
return struct{}{}, err }
} s.balancer.Set(balancer)
s.balancer.Set(balancer) snmanager.StaticStreamingNodeManager.SetBalancerReady(balancer)
snmanager.StaticStreamingNodeManager.SetBalancerReady(balancer) s.logger.Info("recover balancer done")
s.logger.Info("recover balancer done") return struct{}{}, nil
return struct{}{}, nil }))
}))
}
// The broadcaster of msgstream is implemented on current streamingcoord to reduce the development complexity. // The broadcaster of msgstream is implemented on current streamingcoord to reduce the development complexity.
// So we need to recover it. // So we need to recover it.
futures = append(futures, conc.Go(func() (struct{}, error) { futures = append(futures, conc.Go(func() (struct{}, error) {
s.logger.Info("start recovery broadcaster...") s.logger.Info("start recovery broadcaster...")
broadcaster, err := broadcaster.RecoverBroadcaster(ctx, registry.GetAppendOperator()) broadcaster, err := broadcaster.RecoverBroadcaster(ctx)
if err != nil { if err != nil {
s.logger.Warn("recover broadcaster failed", zap.Error(err)) s.logger.Warn("recover broadcaster failed", zap.Error(err))
return struct{}{}, err return struct{}{}, err
@ -87,9 +83,7 @@ func (s *Server) initBasicComponent(ctx context.Context) (err error) {
// RegisterGRPCService register all grpc service to grpc server. // RegisterGRPCService register all grpc service to grpc server.
func (s *Server) RegisterGRPCService(grpcServer *grpc.Server) { func (s *Server) RegisterGRPCService(grpcServer *grpc.Server) {
if streamingutil.IsStreamingServiceEnabled() { streamingpb.RegisterStreamingCoordAssignmentServiceServer(grpcServer, s.assignmentService)
streamingpb.RegisterStreamingCoordAssignmentServiceServer(grpcServer, s.assignmentService)
}
streamingpb.RegisterStreamingCoordBroadcastServiceServer(grpcServer, s.broadcastService) streamingpb.RegisterStreamingCoordBroadcastServiceServer(grpcServer, s.broadcastService)
} }
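
initBasicComponent now always recovers the balancer and the broadcaster concurrently and fails startup if either recovery fails. A hedged sketch of that shape using golang.org/x/sync/errgroup in place of the Milvus conc.Go futures; the recover functions below are hypothetical placeholders:

package main

import (
    "context"
    "fmt"

    "golang.org/x/sync/errgroup"
)

func recoverBalancer(ctx context.Context) error    { fmt.Println("balancer recovered"); return nil }
func recoverBroadcaster(ctx context.Context) error { fmt.Println("broadcaster recovered"); return nil }

// initBasicComponent launches both recoveries in parallel and returns the
// first error, mirroring the futures-then-await structure above.
func initBasicComponent(ctx context.Context) error {
    g, ctx := errgroup.WithContext(ctx)
    g.Go(func() error { return recoverBalancer(ctx) })
    g.Go(func() error { return recoverBroadcaster(ctx) })
    return g.Wait()
}

func main() {
    if err := initBasicComponent(context.Background()); err != nil {
        fmt.Println("init failed:", err)
    }
}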

View File

@ -84,15 +84,6 @@ type DataNodeComponent interface {
// SetEtcdClient set etcd client for DataNode // SetEtcdClient set etcd client for DataNode
SetEtcdClient(etcdClient *clientv3.Client) SetEtcdClient(etcdClient *clientv3.Client)
// SetMixCoordClient set SetMixCoordClient for DataNode
// `mixCoord` is a client of root coordinator.
//
// Return a generic error in status:
// If the mixCoord is nil or the mixCoord has been set before.
// Return nil in status:
// The mixCoord is not nil.
SetMixCoordClient(mixCoord MixCoordClient) error
} }
// DataCoordClient is the client interface for datacoord server // DataCoordClient is the client interface for datacoord server
@ -196,9 +187,6 @@ type ProxyComponent interface {
SetAddress(address string) SetAddress(address string)
GetAddress() string GetAddress() string
// SetEtcdClient set EtcdClient for Proxy
// `etcdClient` is a client of etcd
SetEtcdClient(etcdClient *clientv3.Client)
// SetMixCoordClient set MixCoord for Proxy // SetMixCoordClient set MixCoord for Proxy
// `mixCoord` is a client of mix coordinator. // `mixCoord` is a client of mix coordinator.