enhance: remove old arch non-streaming arch code (#43651)

issue: #41609

- remove all dml dead code at proxy
- remove dead code at l0_write_buffer
- remove msgstream dependency at proxy
- remove timetick reporter from proxy
- remove replicate stream implementation

---------

Signed-off-by: chyezh <chyezh@outlook.com>
This commit is contained in:
Zhen Ye 2025-08-06 14:41:40 +08:00 committed by GitHub
parent 6ae727775f
commit 5551d99425
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
73 changed files with 611 additions and 4512 deletions

View File

@ -146,9 +146,7 @@ func GetMilvusRoles(args []string, flags *flag.FlagSet) *roles.MilvusRoles {
role.EnableProxy = true
role.EnableQueryNode = true
role.EnableDataNode = true
if streamingutil.IsStreamingServiceEnabled() {
role.EnableStreamingNode = true
}
role.EnableStreamingNode = true
role.Local = true
role.Embedded = serverType == typeutil.EmbeddedRole
case typeutil.MixCoordRole:

View File

@ -12,6 +12,7 @@ packages:
Utility:
Broadcast:
Local:
Scanner:
github.com/milvus-io/milvus/internal/streamingcoord/server/balancer:
interfaces:
Balancer:

View File

@ -1462,6 +1462,9 @@ func (s *Server) GetFlushState(ctx context.Context, req *datapb.GetFlushStateReq
for _, sid := range req.GetSegmentIDs() {
segment := s.meta.GetHealthySegment(ctx, sid)
// segment is nil if it was compacted, or it's an empty segment and is set to dropped
// TODO: Here's a dirty implementation, because a growing segment may cannot be seen right away by mixcoord,
// it can only be seen by streamingnode right away, so we need to check the flush state at streamingnode but not here.
// use timetick for GetFlushState in-future but not segment list.
if segment == nil || isFlushState(segment.GetState()) {
continue
}

View File

@ -32,13 +32,10 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
dn "github.com/milvus-io/milvus/internal/datanode"
mix "github.com/milvus-io/milvus/internal/distributed/mixcoord/client"
"github.com/milvus-io/milvus/internal/distributed/utils"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/componentutil"
"github.com/milvus-io/milvus/internal/util/dependency"
_ "github.com/milvus-io/milvus/internal/util/grpcclient"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
@ -65,8 +62,6 @@ type Server struct {
factory dependency.Factory
serverID atomic.Int64
mixCoordClient func() (types.MixCoordClient, error)
}
// NewServer new DataNode grpc server
@ -77,9 +72,6 @@ func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error)
cancel: cancel,
factory: factory,
grpcErrChan: make(chan error),
mixCoordClient: func() (types.MixCoordClient, error) {
return mix.NewClient(ctx1)
},
}
s.serverID.Store(paramtable.GetNodeID())
@ -173,10 +165,6 @@ func (s *Server) SetEtcdClient(client *clientv3.Client) {
s.datanode.SetEtcdClient(client)
}
func (s *Server) SetMixCoordInterface(ms types.MixCoordClient) error {
return s.datanode.SetMixCoordClient(ms)
}
// Run initializes and starts Datanode's grpc service.
func (s *Server) Run() error {
if err := s.init(); err != nil {
@ -255,27 +243,6 @@ func (s *Server) init() error {
return err
}
if !streamingutil.IsStreamingServiceEnabled() {
// --- MixCoord Client ---
if s.mixCoordClient != nil {
log.Info("initializing MixCoord client for DataNode")
mixCoordClient, err := s.mixCoordClient()
if err != nil {
log.Error("failed to create new MixCoord client", zap.Error(err))
panic(err)
}
if err = componentutil.WaitForComponentHealthy(s.ctx, mixCoordClient, "MixCoord", 1000000, time.Millisecond*200); err != nil {
log.Error("failed to wait for MixCoord client to be ready", zap.Error(err))
panic(err)
}
log.Info("MixCoord client is ready for DataNode")
if err = s.SetMixCoordInterface(mixCoordClient); err != nil {
panic(err)
}
}
}
s.datanode.UpdateStateCode(commonpb.StateCode_Initializing)
if err := s.datanode.Init(); err != nil {

View File

@ -27,7 +27,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/workerpb"
@ -43,27 +42,10 @@ func Test_NewServer(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, server)
mockMixCoord := mocks.NewMockMixCoordClient(t)
mockMixCoord.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
StateCode: commonpb.StateCode_Healthy,
},
Status: merr.Success(),
SubcomponentStates: []*milvuspb.ComponentInfo{
{
StateCode: commonpb.StateCode_Healthy,
},
},
}, nil)
server.mixCoordClient = func() (types.MixCoordClient, error) {
return mockMixCoord, nil
}
t.Run("Run", func(t *testing.T) {
datanode := mocks.NewMockDataNode(t)
datanode.EXPECT().SetEtcdClient(mock.Anything).Return()
datanode.EXPECT().SetAddress(mock.Anything).Return()
datanode.EXPECT().SetMixCoordClient(mock.Anything).Return(nil)
datanode.EXPECT().UpdateStateCode(mock.Anything).Return()
datanode.EXPECT().Register().Return(nil)
datanode.EXPECT().Init().Return(nil)
@ -191,26 +173,9 @@ func Test_Run(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, server)
mockRootCoord := mocks.NewMockMixCoordClient(t)
mockRootCoord.EXPECT().GetComponentStates(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
StateCode: commonpb.StateCode_Healthy,
},
Status: merr.Success(),
SubcomponentStates: []*milvuspb.ComponentInfo{
{
StateCode: commonpb.StateCode_Healthy,
},
},
}, nil)
server.mixCoordClient = func() (types.MixCoordClient, error) {
return mockRootCoord, nil
}
datanode := mocks.NewMockDataNode(t)
datanode.EXPECT().SetEtcdClient(mock.Anything).Return()
datanode.EXPECT().SetAddress(mock.Anything).Return()
datanode.EXPECT().SetMixCoordClient(mock.Anything).Return(nil)
datanode.EXPECT().UpdateStateCode(mock.Anything).Return()
datanode.EXPECT().Init().Return(errors.New("mock err"))
server.datanode = datanode
@ -223,7 +188,6 @@ func Test_Run(t *testing.T) {
datanode = mocks.NewMockDataNode(t)
datanode.EXPECT().SetEtcdClient(mock.Anything).Return()
datanode.EXPECT().SetAddress(mock.Anything).Return()
datanode.EXPECT().SetMixCoordClient(mock.Anything).Return(nil)
datanode.EXPECT().UpdateStateCode(mock.Anything).Return()
datanode.EXPECT().Register().Return(nil)
datanode.EXPECT().Init().Return(nil)
@ -242,26 +206,9 @@ func TestIndexService(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, server)
mockRootCoord := mocks.NewMockMixCoordClient(t)
mockRootCoord.EXPECT().GetComponentStates(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{
StateCode: commonpb.StateCode_Healthy,
},
Status: merr.Success(),
SubcomponentStates: []*milvuspb.ComponentInfo{
{
StateCode: commonpb.StateCode_Healthy,
},
},
}, nil)
server.mixCoordClient = func() (types.MixCoordClient, error) {
return mockRootCoord, nil
}
dn := mocks.NewMockDataNode(t)
dn.EXPECT().SetEtcdClient(mock.Anything).Return()
dn.EXPECT().SetAddress(mock.Anything).Return()
dn.EXPECT().SetMixCoordClient(mock.Anything).Return(nil)
dn.EXPECT().UpdateStateCode(mock.Anything).Return()
dn.EXPECT().Register().Return(nil)
dn.EXPECT().Init().Return(nil)

View File

@ -449,7 +449,6 @@ func (s *Server) init() error {
return err
}
s.etcdCli = etcdCli
s.proxy.SetEtcdClient(s.etcdCli)
s.proxy.SetAddress(s.listenerManager.internalGrpcListener.Address())
errChan := make(chan error, 1)

View File

@ -201,7 +201,6 @@ func Test_NewServer(t *testing.T) {
mockProxy.EXPECT().Init().Return(nil)
mockProxy.EXPECT().Start().Return(nil)
mockProxy.EXPECT().Register().Return(nil)
mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
@ -687,7 +686,6 @@ func Test_NewServer(t *testing.T) {
mockProxy.EXPECT().Init().Return(nil)
mockProxy.EXPECT().Start().Return(nil)
mockProxy.EXPECT().Register().Return(nil)
mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
@ -821,7 +819,6 @@ func Test_NewServer_HTTPServer_Enabled(t *testing.T) {
mockProxy.EXPECT().Init().Return(nil)
mockProxy.EXPECT().Start().Return(nil)
mockProxy.EXPECT().Register().Return(nil)
mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
@ -886,7 +883,6 @@ func Test_NewServer_TLS_TwoWay(t *testing.T) {
mockProxy.EXPECT().Init().Return(nil)
mockProxy.EXPECT().Start().Return(nil)
mockProxy.EXPECT().Register().Return(nil)
mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
@ -914,7 +910,6 @@ func Test_NewServer_TLS_OneWay(t *testing.T) {
mockProxy.EXPECT().Init().Return(nil)
mockProxy.EXPECT().Start().Return(nil)
mockProxy.EXPECT().Register().Return(nil)
mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
@ -938,7 +933,6 @@ func Test_NewServer_TLS_FileNotExisted(t *testing.T) {
mockProxy := server.proxy.(*mocks.MockProxy)
mockProxy.EXPECT().Stop().Return(nil)
mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
mockProxy.EXPECT().SetAddress(mock.Anything).Return()
@ -977,7 +971,6 @@ func Test_NewHTTPServer_TLS_TwoWay(t *testing.T) {
mockProxy.EXPECT().Init().Return(nil)
mockProxy.EXPECT().Start().Return(nil)
mockProxy.EXPECT().Register().Return(nil)
mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
@ -1013,7 +1006,6 @@ func Test_NewHTTPServer_TLS_OneWay(t *testing.T) {
mockProxy.EXPECT().Init().Return(nil)
mockProxy.EXPECT().Start().Return(nil)
mockProxy.EXPECT().Register().Return(nil)
mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()
@ -1046,7 +1038,6 @@ func Test_NewHTTPServer_TLS_FileNotExisted(t *testing.T) {
mockProxy := server.proxy.(*mocks.MockProxy)
mockProxy.EXPECT().Stop().Return(nil)
mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return().Maybe()
mockProxy.EXPECT().SetAddress(mock.Anything).Return().Maybe()
Params := &paramtable.Get().ProxyGrpcServerCfg
@ -1160,7 +1151,6 @@ func Test_Service_GracefulStop(t *testing.T) {
mockProxy.EXPECT().Start().Return(nil)
mockProxy.EXPECT().Stop().Return(nil)
mockProxy.EXPECT().Register().Return(nil)
mockProxy.EXPECT().SetEtcdClient(mock.Anything).Return()
mockProxy.EXPECT().GetRateLimiter().Return(nil, nil)
mockProxy.EXPECT().SetMixCoordClient(mock.Anything).Return()
mockProxy.EXPECT().UpdateStateCode(mock.Anything).Return()

View File

@ -4,7 +4,6 @@ import (
"context"
"time"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/options"
@ -18,8 +17,6 @@ var singleton WALAccesser = nil
func Init() {
c, _ := kvfactory.GetEtcdAndPath()
singleton = newWALAccesser(c)
// Add the wal accesser to the broadcaster registry for making broadcast operation.
registry.Register(registry.AppendOperatorTypeStreaming, singleton)
}
// Release releases the resources of the wal accesser.

View File

@ -18,7 +18,20 @@
package streaming
import kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
import (
"context"
"github.com/cockroachdb/errors"
"google.golang.org/protobuf/types/known/anypb"
kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
)
var expectErr = make(chan error, 10)
// SetWALForTest initializes the singleton of wal for test.
func SetWALForTest(w WALAccesser) {
@ -29,3 +42,146 @@ func RecoverWALForTest() {
c, _ := kvfactory.GetEtcdAndPath()
singleton = newWALAccesser(c)
}
func ExpectErrorOnce(err error) {
expectErr <- err
}
func SetupNoopWALForTest() {
singleton = &noopWALAccesser{}
}
type noopLocal struct{}
func (n *noopLocal) GetLatestMVCCTimestampIfLocal(ctx context.Context, vchannel string) (uint64, error) {
return 0, errors.New("not implemented")
}
func (n *noopLocal) GetMetricsIfLocal(ctx context.Context) (*types.StreamingNodeMetrics, error) {
return &types.StreamingNodeMetrics{}, nil
}
type noopBroadcast struct{}
func (n *noopBroadcast) Append(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) {
if err := getExpectErr(); err != nil {
return nil, err
}
return &types.BroadcastAppendResult{
BroadcastID: 1,
AppendResults: map[string]*types.AppendResult{
"v1": {
MessageID: rmq.NewRmqID(1),
TimeTick: 10,
Extra: &anypb.Any{},
},
},
}, nil
}
func (n *noopBroadcast) Ack(ctx context.Context, req types.BroadcastAckRequest) error {
return nil
}
type noopTxn struct{}
func (n *noopTxn) Append(ctx context.Context, msg message.MutableMessage, opts ...AppendOption) error {
if err := getExpectErr(); err != nil {
return err
}
return nil
}
func (n *noopTxn) Commit(ctx context.Context) (*types.AppendResult, error) {
if err := getExpectErr(); err != nil {
return nil, err
}
return &types.AppendResult{}, nil
}
func (n *noopTxn) Rollback(ctx context.Context) error {
if err := getExpectErr(); err != nil {
return err
}
return nil
}
type noopWALAccesser struct{}
func (n *noopWALAccesser) WALName() string {
return "noop"
}
func (n *noopWALAccesser) Local() Local {
return &noopLocal{}
}
func (n *noopWALAccesser) Txn(ctx context.Context, opts TxnOption) (Txn, error) {
if err := getExpectErr(); err != nil {
return nil, err
}
return &noopTxn{}, nil
}
func (n *noopWALAccesser) RawAppend(ctx context.Context, msgs message.MutableMessage, opts ...AppendOption) (*types.AppendResult, error) {
if err := getExpectErr(); err != nil {
return nil, err
}
extra, _ := anypb.New(&messagespb.ManualFlushExtraResponse{
SegmentIds: []int64{1, 2, 3},
})
return &types.AppendResult{
MessageID: rmq.NewRmqID(1),
TimeTick: 10,
Extra: extra,
}, nil
}
func (n *noopWALAccesser) Broadcast() Broadcast {
return &noopBroadcast{}
}
func (n *noopWALAccesser) Read(ctx context.Context, opts ReadOption) Scanner {
return &noopScanner{}
}
func (n *noopWALAccesser) AppendMessages(ctx context.Context, msgs ...message.MutableMessage) AppendResponses {
if err := getExpectErr(); err != nil {
return AppendResponses{
Responses: []AppendResponse{
{
AppendResult: nil,
Error: err,
},
},
}
}
return AppendResponses{}
}
func (n *noopWALAccesser) AppendMessagesWithOption(ctx context.Context, opts AppendOption, msgs ...message.MutableMessage) AppendResponses {
return AppendResponses{}
}
type noopScanner struct{}
func (n *noopScanner) Done() <-chan struct{} {
return make(chan struct{})
}
func (n *noopScanner) Error() error {
return nil
}
func (n *noopScanner) Close() {
}
// getExpectErr is a helper function to get the error from the expectErr channel.
func getExpectErr() error {
select {
case err := <-expectErr:
return err
default:
return nil
}
}

View File

@ -11,7 +11,6 @@ import (
"github.com/milvus-io/milvus/internal/distributed/streaming/internal/producer"
"github.com/milvus-io/milvus/internal/streamingcoord/client"
"github.com/milvus-io/milvus/internal/streamingnode/client/handler"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/internal/util/streamingutil/status"
"github.com/milvus-io/milvus/internal/util/streamingutil/util"
"github.com/milvus-io/milvus/pkg/v2/log"
@ -29,11 +28,7 @@ func newWALAccesser(c *clientv3.Client) *walAccesserImpl {
// Create a new streaming coord client.
streamingCoordClient := client.NewClient(c)
// Create a new streamingnode handler client.
var handlerClient handler.HandlerClient
if streamingutil.IsStreamingServiceEnabled() {
// streaming service is enabled, create the handler client for the streaming service.
handlerClient = handler.NewHandlerClient(streamingCoordClient.Assignment())
}
handlerClient := handler.NewHandlerClient(streamingCoordClient.Assignment())
w := &walAccesserImpl{
lifetime: typeutil.NewLifetime(),
streamingCoordClient: streamingCoordClient,

View File

@ -4,27 +4,21 @@ import (
"context"
"fmt"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/flushcommon/io"
"github.com/milvus-io/milvus/internal/flushcommon/metacache"
"github.com/milvus-io/milvus/internal/flushcommon/metacache/pkoracle"
"github.com/milvus-io/milvus/internal/flushcommon/syncmgr"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/util/conc"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/retry"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
type l0WriteBuffer struct {
@ -54,94 +48,6 @@ func NewL0WriteBuffer(channel string, metacache metacache.MetaCache, syncMgr syn
}, nil
}
func (wb *l0WriteBuffer) dispatchDeleteMsgs(groups []*InsertData, deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) {
batchSize := paramtable.Get().CommonCfg.BloomFilterApplyBatchSize.GetAsInt()
split := func(pks []storage.PrimaryKey, pkTss []uint64, partitionSegments []*metacache.SegmentInfo, partitionGroups []*InsertData) []bool {
lc := storage.NewBatchLocationsCache(pks)
// use hits to cache result
hits := make([]bool, len(pks))
for _, segment := range partitionSegments {
hits = segment.GetBloomFilterSet().BatchPkExistWithHits(lc, hits)
}
for _, inData := range partitionGroups {
hits = inData.batchPkExists(pks, pkTss, hits)
}
return hits
}
type BatchApplyRet = struct {
// represent the idx for delete msg in deleteMsgs
DeleteDataIdx int
// represent the start idx for the batch in each deleteMsg
StartIdx int
Hits []bool
}
// transform pk to primary key
pksInDeleteMsgs := lo.Map(deleteMsgs, func(delMsg *msgstream.DeleteMsg, _ int) []storage.PrimaryKey {
return storage.ParseIDs2PrimaryKeys(delMsg.GetPrimaryKeys())
})
retIdx := 0
retMap := typeutil.NewConcurrentMap[int, *BatchApplyRet]()
pool := io.GetBFApplyPool()
var futures []*conc.Future[any]
for didx, delMsg := range deleteMsgs {
pks := pksInDeleteMsgs[didx]
pkTss := delMsg.GetTimestamps()
partitionSegments := wb.metaCache.GetSegmentsBy(metacache.WithPartitionID(delMsg.PartitionID),
metacache.WithSegmentState(commonpb.SegmentState_Growing, commonpb.SegmentState_Sealed, commonpb.SegmentState_Flushing, commonpb.SegmentState_Flushed))
partitionGroups := lo.Filter(groups, func(inData *InsertData, _ int) bool {
return delMsg.GetPartitionID() == common.AllPartitionsID || delMsg.GetPartitionID() == inData.partitionID
})
for idx := 0; idx < len(pks); idx += batchSize {
startIdx := idx
endIdx := idx + batchSize
if endIdx > len(pks) {
endIdx = len(pks)
}
retIdx += 1
tmpRetIdx := retIdx
deleteDataId := didx
future := pool.Submit(func() (any, error) {
hits := split(pks[startIdx:endIdx], pkTss[startIdx:endIdx], partitionSegments, partitionGroups)
retMap.Insert(tmpRetIdx, &BatchApplyRet{
DeleteDataIdx: deleteDataId,
StartIdx: startIdx,
Hits: hits,
})
return nil, nil
})
futures = append(futures, future)
}
}
conc.AwaitAll(futures...)
retMap.Range(func(key int, value *BatchApplyRet) bool {
l0SegmentID := wb.getL0SegmentID(deleteMsgs[value.DeleteDataIdx].GetPartitionID(), startPos)
pks := pksInDeleteMsgs[value.DeleteDataIdx]
pkTss := deleteMsgs[value.DeleteDataIdx].GetTimestamps()
var deletePks []storage.PrimaryKey
var deleteTss []typeutil.Timestamp
for i, hit := range value.Hits {
if hit {
deletePks = append(deletePks, pks[value.StartIdx+i])
deleteTss = append(deleteTss, pkTss[value.StartIdx+i])
}
}
if len(deletePks) > 0 {
wb.bufferDelete(l0SegmentID, deletePks, deleteTss, startPos, endPos)
}
return true
})
}
func (wb *l0WriteBuffer) dispatchDeleteMsgsWithoutFilter(deleteMsgs []*msgstream.DeleteMsg, startPos, endPos *msgpb.MsgPosition) {
for _, msg := range deleteMsgs {
l0SegmentID := wb.getL0SegmentID(msg.GetPartitionID(), startPos)
@ -165,31 +71,10 @@ func (wb *l0WriteBuffer) BufferData(insertData []*InsertData, deleteMsgs []*msgs
}
}
if paramtable.Get().DataNodeCfg.SkipBFStatsLoad.GetAsBool() || streamingutil.IsStreamingServiceEnabled() {
// In streaming service mode, flushed segments no longer maintain a bloom filter.
// So, here we skip generating BF (growing segment's BF will be regenerated during the sync phase)
// and also skip filtering delete entries by bf.
wb.dispatchDeleteMsgsWithoutFilter(deleteMsgs, startPos, endPos)
} else {
// distribute delete msg
// bf write buffer check bloom filter of segment and current insert batch to decide which segment to write delete data
wb.dispatchDeleteMsgs(insertData, deleteMsgs, startPos, endPos)
// update pk oracle
for _, inData := range insertData {
// segment shall always exists after buffer insert
segments := wb.metaCache.GetSegmentsBy(metacache.WithSegmentIDs(inData.segmentID))
for _, segment := range segments {
for _, fieldData := range inData.pkField {
err := segment.GetBloomFilterSet().UpdatePKRange(fieldData)
if err != nil {
return err
}
}
}
}
}
// In streaming service mode, flushed segments no longer maintain a bloom filter.
// So, here we skip generating BF (growing segment's BF will be regenerated during the sync phase)
// and also skip filtering delete entries by bf.
wb.dispatchDeleteMsgsWithoutFilter(deleteMsgs, startPos, endPos)
// update buffer last checkpoint
wb.checkpoint = endPos

View File

@ -18,7 +18,6 @@ import (
"github.com/milvus-io/milvus/internal/flushcommon/metacache/pkoracle"
"github.com/milvus-io/milvus/internal/flushcommon/syncmgr"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
@ -324,10 +323,8 @@ func (wb *writeBufferBase) syncSegments(ctx context.Context, segmentIDs []int64)
}
if syncTask.IsFlush() {
if paramtable.Get().DataNodeCfg.SkipBFStatsLoad.GetAsBool() || streamingutil.IsStreamingServiceEnabled() {
wb.metaCache.RemoveSegments(metacache.WithSegmentIDs(syncTask.SegmentID()))
log.Info("flushed segment removed", zap.Int64("segmentID", syncTask.SegmentID()), zap.String("channel", syncTask.ChannelName()))
}
wb.metaCache.RemoveSegments(metacache.WithSegmentIDs(syncTask.SegmentID()))
log.Info("flushed segment removed", zap.Int64("segmentID", syncTask.SegmentID()), zap.String("channel", syncTask.ChannelName()))
}
return nil
})

View File

@ -0,0 +1,156 @@
// Code generated by mockery v2.53.3. DO NOT EDIT.
package mock_streaming
import mock "github.com/stretchr/testify/mock"
// MockScanner is an autogenerated mock type for the Scanner type
type MockScanner struct {
mock.Mock
}
type MockScanner_Expecter struct {
mock *mock.Mock
}
func (_m *MockScanner) EXPECT() *MockScanner_Expecter {
return &MockScanner_Expecter{mock: &_m.Mock}
}
// Close provides a mock function with no fields
func (_m *MockScanner) Close() {
_m.Called()
}
// MockScanner_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
type MockScanner_Close_Call struct {
*mock.Call
}
// Close is a helper method to define mock.On call
func (_e *MockScanner_Expecter) Close() *MockScanner_Close_Call {
return &MockScanner_Close_Call{Call: _e.mock.On("Close")}
}
func (_c *MockScanner_Close_Call) Run(run func()) *MockScanner_Close_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockScanner_Close_Call) Return() *MockScanner_Close_Call {
_c.Call.Return()
return _c
}
func (_c *MockScanner_Close_Call) RunAndReturn(run func()) *MockScanner_Close_Call {
_c.Run(run)
return _c
}
// Done provides a mock function with no fields
func (_m *MockScanner) Done() <-chan struct{} {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Done")
}
var r0 <-chan struct{}
if rf, ok := ret.Get(0).(func() <-chan struct{}); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(<-chan struct{})
}
}
return r0
}
// MockScanner_Done_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Done'
type MockScanner_Done_Call struct {
*mock.Call
}
// Done is a helper method to define mock.On call
func (_e *MockScanner_Expecter) Done() *MockScanner_Done_Call {
return &MockScanner_Done_Call{Call: _e.mock.On("Done")}
}
func (_c *MockScanner_Done_Call) Run(run func()) *MockScanner_Done_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockScanner_Done_Call) Return(_a0 <-chan struct{}) *MockScanner_Done_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockScanner_Done_Call) RunAndReturn(run func() <-chan struct{}) *MockScanner_Done_Call {
_c.Call.Return(run)
return _c
}
// Error provides a mock function with no fields
func (_m *MockScanner) Error() error {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Error")
}
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
} else {
r0 = ret.Error(0)
}
return r0
}
// MockScanner_Error_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Error'
type MockScanner_Error_Call struct {
*mock.Call
}
// Error is a helper method to define mock.On call
func (_e *MockScanner_Expecter) Error() *MockScanner_Error_Call {
return &MockScanner_Error_Call{Call: _e.mock.On("Error")}
}
func (_c *MockScanner_Error_Call) Run(run func()) *MockScanner_Error_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockScanner_Error_Call) Return(_a0 error) *MockScanner_Error_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockScanner_Error_Call) RunAndReturn(run func() error) *MockScanner_Error_Call {
_c.Call.Return(run)
return _c
}
// NewMockScanner creates a new instance of MockScanner. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockScanner(t interface {
mock.TestingT
Cleanup(func())
}) *MockScanner {
mock := &MockScanner{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -16,8 +16,6 @@ import (
mock "github.com/stretchr/testify/mock"
types "github.com/milvus-io/milvus/internal/types"
workerpb "github.com/milvus-io/milvus/pkg/v2/proto/workerpb"
)
@ -1918,52 +1916,6 @@ func (_c *MockDataNode_SetEtcdClient_Call) RunAndReturn(run func(*clientv3.Clien
return _c
}
// SetMixCoordClient provides a mock function with given fields: mixCoord
func (_m *MockDataNode) SetMixCoordClient(mixCoord types.MixCoordClient) error {
ret := _m.Called(mixCoord)
if len(ret) == 0 {
panic("no return value specified for SetMixCoordClient")
}
var r0 error
if rf, ok := ret.Get(0).(func(types.MixCoordClient) error); ok {
r0 = rf(mixCoord)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockDataNode_SetMixCoordClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetMixCoordClient'
type MockDataNode_SetMixCoordClient_Call struct {
*mock.Call
}
// SetMixCoordClient is a helper method to define mock.On call
// - mixCoord types.MixCoordClient
func (_e *MockDataNode_Expecter) SetMixCoordClient(mixCoord interface{}) *MockDataNode_SetMixCoordClient_Call {
return &MockDataNode_SetMixCoordClient_Call{Call: _e.mock.On("SetMixCoordClient", mixCoord)}
}
func (_c *MockDataNode_SetMixCoordClient_Call) Run(run func(mixCoord types.MixCoordClient)) *MockDataNode_SetMixCoordClient_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(types.MixCoordClient))
})
return _c
}
func (_c *MockDataNode_SetMixCoordClient_Call) Return(_a0 error) *MockDataNode_SetMixCoordClient_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockDataNode_SetMixCoordClient_Call) RunAndReturn(run func(types.MixCoordClient) error) *MockDataNode_SetMixCoordClient_Call {
_c.Call.Return(run)
return _c
}
// ShowConfigurations provides a mock function with given fields: _a0, _a1
func (_m *MockDataNode) ShowConfigurations(_a0 context.Context, _a1 *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) {
ret := _m.Called(_a0, _a1)

View File

@ -6,7 +6,6 @@ import (
context "context"
commonpb "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
clientv3 "go.etcd.io/etcd/client/v3"
federpb "github.com/milvus-io/milvus-proto/go-api/v2/federpb"
@ -6690,39 +6689,6 @@ func (_c *MockProxy_SetAddress_Call) RunAndReturn(run func(string)) *MockProxy_S
return _c
}
// SetEtcdClient provides a mock function with given fields: etcdClient
func (_m *MockProxy) SetEtcdClient(etcdClient *clientv3.Client) {
_m.Called(etcdClient)
}
// MockProxy_SetEtcdClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetEtcdClient'
type MockProxy_SetEtcdClient_Call struct {
*mock.Call
}
// SetEtcdClient is a helper method to define mock.On call
// - etcdClient *clientv3.Client
func (_e *MockProxy_Expecter) SetEtcdClient(etcdClient interface{}) *MockProxy_SetEtcdClient_Call {
return &MockProxy_SetEtcdClient_Call{Call: _e.mock.On("SetEtcdClient", etcdClient)}
}
func (_c *MockProxy_SetEtcdClient_Call) Run(run func(etcdClient *clientv3.Client)) *MockProxy_SetEtcdClient_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(*clientv3.Client))
})
return _c
}
func (_c *MockProxy_SetEtcdClient_Call) Return() *MockProxy_SetEtcdClient_Call {
_c.Call.Return()
return _c
}
func (_c *MockProxy_SetEtcdClient_Call) RunAndReturn(run func(*clientv3.Client)) *MockProxy_SetEtcdClient_Call {
_c.Run(run)
return _c
}
// SetMixCoordClient provides a mock function with given fields: rootCoord
func (_m *MockProxy) SetMixCoordClient(rootCoord types.MixCoordClient) {
_m.Called(rootCoord)

View File

@ -39,9 +39,7 @@ import (
type channelsMgr interface {
getChannels(collectionID UniqueID) ([]pChan, error)
getVChannels(collectionID UniqueID) ([]vChan, error)
getOrCreateDmlStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error)
removeDMLStream(collectionID UniqueID)
removeAllDMLStream()
}
type channelInfos struct {
@ -52,7 +50,6 @@ type channelInfos struct {
type streamInfos struct {
channelInfos channelInfos
stream msgstream.MsgStream
}
func removeDuplicate(ss []string) []string {
@ -114,9 +111,8 @@ type singleTypeChannelsMgr struct {
infos map[UniqueID]streamInfos // collection id -> stream infos
mu sync.RWMutex
getChannelsFunc getChannelsFuncType
repackFunc repackFuncType
msgStreamFactory msgstream.Factory
getChannelsFunc getChannelsFuncType
repackFunc repackFuncType
}
func (mgr *singleTypeChannelsMgr) getAllChannels(collectionID UniqueID) (channelInfos, error) {
@ -167,27 +163,6 @@ func (mgr *singleTypeChannelsMgr) getVChannels(collectionID UniqueID) ([]vChan,
return channelInfos.vchans, nil
}
func (mgr *singleTypeChannelsMgr) streamExistPrivate(collectionID UniqueID) bool {
streamInfos, ok := mgr.infos[collectionID]
return ok && streamInfos.stream != nil
}
func createStream(ctx context.Context, factory msgstream.Factory, pchans []pChan, repack repackFuncType) (msgstream.MsgStream, error) {
var stream msgstream.MsgStream
var err error
stream, err = factory.NewMsgStream(context.Background())
if err != nil {
return nil, err
}
stream.AsProducer(ctx, pchans)
if repack != nil {
stream.SetRepackFunc(repack)
}
return stream, nil
}
func incPChansMetrics(pchans []pChan) {
for _, pc := range pchans {
metrics.ProxyMsgStreamObjectsForPChan.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), pc).Inc()
@ -200,67 +175,6 @@ func decPChanMetrics(pchans []pChan) {
}
}
// createMsgStream create message stream for specified collection. Idempotent.
// If stream already exists, directly return it and no error will be returned.
func (mgr *singleTypeChannelsMgr) createMsgStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
mgr.mu.RLock()
infos, ok := mgr.infos[collectionID]
if ok && infos.stream != nil {
// already exist.
mgr.mu.RUnlock()
return infos.stream, nil
}
mgr.mu.RUnlock()
channelInfos, err := mgr.getChannelsFunc(collectionID)
if err != nil {
// What if stream created by other goroutines?
log.Error("failed to get channels", zap.Error(err), zap.Int64("collection", collectionID))
return nil, err
}
stream, err := createStream(ctx, mgr.msgStreamFactory, channelInfos.pchans, mgr.repackFunc)
if err != nil {
// What if stream created by other goroutines?
log.Error("failed to create message stream", zap.Error(err), zap.Int64("collection", collectionID))
return nil, err
}
mgr.mu.Lock()
defer mgr.mu.Unlock()
if !mgr.streamExistPrivate(collectionID) {
log.Info("create message stream", zap.Int64("collection", collectionID),
zap.Strings("virtual_channels", channelInfos.vchans),
zap.Strings("physical_channels", channelInfos.pchans))
mgr.infos[collectionID] = streamInfos{channelInfos: channelInfos, stream: stream}
incPChansMetrics(channelInfos.pchans)
} else {
stream.Close()
}
return mgr.infos[collectionID].stream, nil
}
func (mgr *singleTypeChannelsMgr) lockGetStream(collectionID UniqueID) (msgstream.MsgStream, error) {
mgr.mu.RLock()
defer mgr.mu.RUnlock()
streamInfos, ok := mgr.infos[collectionID]
if ok {
return streamInfos.stream, nil
}
return nil, fmt.Errorf("collection not found: %d", collectionID)
}
// getOrCreateStream get message stream of specified collection.
// If stream doesn't exist, call createMsgStream to create for it.
func (mgr *singleTypeChannelsMgr) getOrCreateStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
if stream, err := mgr.lockGetStream(collectionID); err == nil {
return stream, nil
}
return mgr.createMsgStream(ctx, collectionID)
}
// removeStream remove the corresponding stream of the specified collection. Idempotent.
// If stream already exists, remove it, otherwise do nothing.
func (mgr *singleTypeChannelsMgr) removeStream(collectionID UniqueID) {
@ -268,34 +182,19 @@ func (mgr *singleTypeChannelsMgr) removeStream(collectionID UniqueID) {
defer mgr.mu.Unlock()
if info, ok := mgr.infos[collectionID]; ok {
decPChanMetrics(info.channelInfos.pchans)
info.stream.Close()
delete(mgr.infos, collectionID)
}
log.Info("dml stream removed", zap.Int64("collection_id", collectionID))
}
// removeAllStream remove all message stream.
func (mgr *singleTypeChannelsMgr) removeAllStream() {
mgr.mu.Lock()
defer mgr.mu.Unlock()
for _, info := range mgr.infos {
info.stream.Close()
decPChanMetrics(info.channelInfos.pchans)
}
mgr.infos = make(map[UniqueID]streamInfos)
log.Info("all dml stream removed")
}
func newSingleTypeChannelsMgr(
getChannelsFunc getChannelsFuncType,
msgStreamFactory msgstream.Factory,
repackFunc repackFuncType,
) *singleTypeChannelsMgr {
return &singleTypeChannelsMgr{
infos: make(map[UniqueID]streamInfos),
getChannelsFunc: getChannelsFunc,
repackFunc: repackFunc,
msgStreamFactory: msgStreamFactory,
infos: make(map[UniqueID]streamInfos),
getChannelsFunc: getChannelsFunc,
repackFunc: repackFunc,
}
}
@ -315,25 +214,16 @@ func (mgr *channelsMgrImpl) getVChannels(collectionID UniqueID) ([]vChan, error)
return mgr.dmlChannelsMgr.getVChannels(collectionID)
}
func (mgr *channelsMgrImpl) getOrCreateDmlStream(ctx context.Context, collectionID UniqueID) (msgstream.MsgStream, error) {
return mgr.dmlChannelsMgr.getOrCreateStream(ctx, collectionID)
}
func (mgr *channelsMgrImpl) removeDMLStream(collectionID UniqueID) {
mgr.dmlChannelsMgr.removeStream(collectionID)
}
func (mgr *channelsMgrImpl) removeAllDMLStream() {
mgr.dmlChannelsMgr.removeAllStream()
}
// newChannelsMgrImpl constructs a channels manager.
func newChannelsMgrImpl(
getDmlChannelsFunc getChannelsFuncType,
dmlRepackFunc repackFuncType,
msgStreamFactory msgstream.Factory,
) *channelsMgrImpl {
return &channelsMgrImpl{
dmlChannelsMgr: newSingleTypeChannelsMgr(getDmlChannelsFunc, msgStreamFactory, dmlRepackFunc),
dmlChannelsMgr: newSingleTypeChannelsMgr(getDmlChannelsFunc, dmlRepackFunc),
}
}

View File

@ -18,7 +18,6 @@ package proxy
import (
"context"
"sync"
"testing"
"github.com/cockroachdb/errors"
@ -28,8 +27,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
func Test_removeDuplicate(t *testing.T) {
@ -205,220 +202,11 @@ func Test_singleTypeChannelsMgr_getVChannels(t *testing.T) {
})
}
func Test_createStream(t *testing.T) {
t.Run("failed to create msgstream", func(t *testing.T) {
factory := newMockMsgStreamFactory()
factory.fQStream = func(ctx context.Context) (msgstream.MsgStream, error) {
return nil, errors.New("mock")
}
_, err := createStream(context.TODO(), factory, nil, nil)
assert.Error(t, err)
})
t.Run("failed to create query msgstream", func(t *testing.T) {
factory := newMockMsgStreamFactory()
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return nil, errors.New("mock")
}
_, err := createStream(context.TODO(), factory, nil, nil)
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
factory := newMockMsgStreamFactory()
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return newMockMsgStream(), nil
}
_, err := createStream(context.TODO(), factory, []string{"111"}, func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
return nil, nil
})
assert.NoError(t, err)
})
}
func Test_singleTypeChannelsMgr_createMsgStream(t *testing.T) {
paramtable.Init()
t.Run("re-create", func(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{
100: {stream: newMockMsgStream()},
},
}
stream, err := m.createMsgStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
t.Run("concurrent create", func(t *testing.T) {
factory := newMockMsgStreamFactory()
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return newMockMsgStream(), nil
}
stopCh := make(chan struct{})
readyCh := make(chan struct{})
m := &singleTypeChannelsMgr{
infos: make(map[UniqueID]streamInfos),
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
close(readyCh)
<-stopCh
return channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}, nil
},
msgStreamFactory: factory,
repackFunc: nil,
}
firstStream := streamInfos{stream: newMockMsgStream()}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
stream, err := m.createMsgStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
}()
// make sure create msg stream has run at getchannels
<-readyCh
// mock create stream for same collection in same time.
m.mu.Lock()
m.infos[100] = firstStream
m.mu.Unlock()
close(stopCh)
wg.Wait()
})
t.Run("failed to get channels", func(t *testing.T) {
m := &singleTypeChannelsMgr{
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
return channelInfos{}, errors.New("mock")
},
}
_, err := m.createMsgStream(context.TODO(), 100)
assert.Error(t, err)
})
t.Run("failed to create message stream", func(t *testing.T) {
factory := newMockMsgStreamFactory()
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return nil, errors.New("mock")
}
m := &singleTypeChannelsMgr{
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
return channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}, nil
},
msgStreamFactory: factory,
repackFunc: nil,
}
_, err := m.createMsgStream(context.TODO(), 100)
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
factory := newMockMsgStreamFactory()
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return newMockMsgStream(), nil
}
m := &singleTypeChannelsMgr{
infos: make(map[UniqueID]streamInfos),
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
return channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}, nil
},
msgStreamFactory: factory,
repackFunc: nil,
}
stream, err := m.createMsgStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
stream, err = m.getOrCreateStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
}
func Test_singleTypeChannelsMgr_lockGetStream(t *testing.T) {
t.Run("collection not found", func(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: make(map[UniqueID]streamInfos),
}
_, err := m.lockGetStream(100)
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{
100: {stream: newMockMsgStream()},
},
}
stream, err := m.lockGetStream(100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
}
func Test_singleTypeChannelsMgr_getStream(t *testing.T) {
t.Run("exist", func(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{
100: {stream: newMockMsgStream()},
},
}
stream, err := m.getOrCreateStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
t.Run("failed to create", func(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{},
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
return channelInfos{}, errors.New("mock")
},
}
_, err := m.getOrCreateStream(context.TODO(), 100)
assert.Error(t, err)
})
t.Run("get after create", func(t *testing.T) {
factory := newMockMsgStreamFactory()
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return newMockMsgStream(), nil
}
m := &singleTypeChannelsMgr{
infos: make(map[UniqueID]streamInfos),
getChannelsFunc: func(collectionID UniqueID) (channelInfos, error) {
return channelInfos{vchans: []string{"111", "222"}, pchans: []string{"111"}}, nil
},
msgStreamFactory: factory,
repackFunc: nil,
}
stream, err := m.getOrCreateStream(context.TODO(), 100)
assert.NoError(t, err)
assert.NotNil(t, stream)
})
}
func Test_singleTypeChannelsMgr_removeStream(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{
100: {
stream: newMockMsgStream(),
},
100: {},
},
}
m.removeStream(100)
_, err := m.lockGetStream(100)
assert.Error(t, err)
}
func Test_singleTypeChannelsMgr_removeAllStream(t *testing.T) {
m := &singleTypeChannelsMgr{
infos: map[UniqueID]streamInfos{
100: {
stream: newMockMsgStream(),
},
},
}
m.removeAllStream()
_, err := m.lockGetStream(100)
assert.Error(t, err)
}

View File

@ -18,7 +18,6 @@ package proxy
import (
"context"
"encoding/base64"
"fmt"
"os"
"strconv"
@ -46,7 +45,6 @@ import (
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/ctokenizer"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
@ -255,7 +253,6 @@ func (node *Proxy) CreateDatabase(ctx context.Context, request *milvuspb.CreateD
Condition: NewTaskCondition(ctx),
CreateDatabaseRequest: request,
mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
}
log := log.Ctx(ctx).With(
@ -323,7 +320,6 @@ func (node *Proxy) DropDatabase(ctx context.Context, request *milvuspb.DropDatab
Condition: NewTaskCondition(ctx),
DropDatabaseRequest: request,
mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
}
log := log.Ctx(ctx).With(
@ -452,7 +448,6 @@ func (node *Proxy) AlterDatabase(ctx context.Context, request *milvuspb.AlterDat
Condition: NewTaskCondition(ctx),
AlterDatabaseRequest: request,
mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
}
log := log.Ctx(ctx).With(
@ -667,7 +662,6 @@ func (node *Proxy) DropCollection(ctx context.Context, request *milvuspb.DropCol
DropCollectionRequest: request,
mixCoord: node.mixCoord,
chMgr: node.chMgr,
chTicker: node.chTicker,
}
log := log.Ctx(ctx).With(
@ -833,7 +827,6 @@ func (node *Proxy) LoadCollection(ctx context.Context, request *milvuspb.LoadCol
Condition: NewTaskCondition(ctx),
LoadCollectionRequest: request,
mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
}
log := log.Ctx(ctx).With(
@ -908,7 +901,6 @@ func (node *Proxy) ReleaseCollection(ctx context.Context, request *milvuspb.Rele
Condition: NewTaskCondition(ctx),
ReleaseCollectionRequest: request,
mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
}
log := log.Ctx(ctx).With(
@ -1278,7 +1270,6 @@ func (node *Proxy) AlterCollection(ctx context.Context, request *milvuspb.AlterC
Condition: NewTaskCondition(ctx),
AlterCollectionRequest: request,
mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
}
log := log.Ctx(ctx).With(
@ -1343,7 +1334,6 @@ func (node *Proxy) AlterCollectionField(ctx context.Context, request *milvuspb.A
Condition: NewTaskCondition(ctx),
AlterCollectionFieldRequest: request,
mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
}
log := log.Ctx(ctx).With(
@ -1616,7 +1606,6 @@ func (node *Proxy) LoadPartitions(ctx context.Context, request *milvuspb.LoadPar
Condition: NewTaskCondition(ctx),
LoadPartitionsRequest: request,
mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
}
log := log.Ctx(ctx).With(
@ -1682,7 +1671,6 @@ func (node *Proxy) ReleasePartitions(ctx context.Context, request *milvuspb.Rele
Condition: NewTaskCondition(ctx),
ReleasePartitionsRequest: request,
mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
}
method := "ReleasePartitions"
@ -2082,11 +2070,10 @@ func (node *Proxy) CreateIndex(ctx context.Context, request *milvuspb.CreateInde
defer sp.End()
cit := &createIndexTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
req: request,
mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
ctx: ctx,
Condition: NewTaskCondition(ctx),
req: request,
mixCoord: node.mixCoord,
}
method := "CreateIndex"
@ -2152,11 +2139,10 @@ func (node *Proxy) AlterIndex(ctx context.Context, request *milvuspb.AlterIndexR
defer sp.End()
task := &alterIndexTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
req: request,
mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
ctx: ctx,
Condition: NewTaskCondition(ctx),
req: request,
mixCoord: node.mixCoord,
}
method := "AlterIndex"
@ -2370,11 +2356,10 @@ func (node *Proxy) DropIndex(ctx context.Context, request *milvuspb.DropIndexReq
defer sp.End()
dit := &dropIndexTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
DropIndexRequest: request,
mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
ctx: ctx,
Condition: NewTaskCondition(ctx),
DropIndexRequest: request,
mixCoord: node.mixCoord,
}
method := "DropIndex"
@ -2630,17 +2615,9 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
},
},
idAllocator: node.rowIDAllocator,
segIDAssigner: node.segAssigner,
chMgr: node.chMgr,
chTicker: node.chTicker,
schemaTimestamp: request.SchemaTimestamp,
}
var enqueuedTask task = it
if streamingutil.IsStreamingServiceEnabled() {
enqueuedTask = &insertTaskByStreamingService{
insertTask: it,
}
}
constructFailedResponse := func(err error) *milvuspb.MutationResult {
numRows := request.NumRows
@ -2657,7 +2634,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
log.Debug("Enqueue insert request in Proxy")
if err := node.sched.dmQueue.Enqueue(enqueuedTask); err != nil {
if err := node.sched.dmQueue.Enqueue(it); err != nil {
log.Warn("Failed to enqueue insert task: " + err.Error())
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
metrics.AbandonLabel, request.GetDbName(), request.GetCollectionName()).Inc()
@ -2765,7 +2742,6 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
idAllocator: node.rowIDAllocator,
tsoAllocatorIns: node.tsoAllocator,
chMgr: node.chMgr,
chTicker: node.chTicker,
queue: node.sched.dmQueue,
lb: node.lbPolicy,
limiter: limiter,
@ -2874,23 +2850,15 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest)
},
idAllocator: node.rowIDAllocator,
segIDAssigner: node.segAssigner,
chMgr: node.chMgr,
chTicker: node.chTicker,
schemaTimestamp: request.SchemaTimestamp,
}
var enqueuedTask task = it
if streamingutil.IsStreamingServiceEnabled() {
enqueuedTask = &upsertTaskByStreamingService{
upsertTask: it,
}
}
log.Debug("Enqueue upsert request in Proxy",
zap.Int("len(FieldsData)", len(request.FieldsData)),
zap.Int("len(HashKeys)", len(request.HashKeys)))
if err := node.sched.dmQueue.Enqueue(enqueuedTask); err != nil {
if err := node.sched.dmQueue.Enqueue(it); err != nil {
log.Info("Failed to enqueue upsert task",
zap.Error(err))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method,
@ -3561,11 +3529,11 @@ func (node *Proxy) Flush(ctx context.Context, request *milvuspb.FlushRequest) (*
defer sp.End()
ft := &flushTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
FlushRequest: request,
mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
ctx: ctx,
Condition: NewTaskCondition(ctx),
FlushRequest: request,
mixCoord: node.mixCoord,
chMgr: node.chMgr,
}
method := "Flush"
@ -3578,16 +3546,7 @@ func (node *Proxy) Flush(ctx context.Context, request *milvuspb.FlushRequest) (*
zap.Any("collections", request.CollectionNames))
log.Debug(rpcReceived(method))
var enqueuedTask task = ft
if streamingutil.IsStreamingServiceEnabled() {
enqueuedTask = &flushTaskByStreamingService{
flushTask: ft,
chMgr: node.chMgr,
}
}
if err := node.sched.dcQueue.Enqueue(enqueuedTask); err != nil {
if err := node.sched.dcQueue.Enqueue(ft); err != nil {
log.Warn(
rpcFailedToEnqueue(method),
zap.Error(err))
@ -3846,7 +3805,6 @@ func (node *Proxy) CreateAlias(ctx context.Context, request *milvuspb.CreateAlia
Condition: NewTaskCondition(ctx),
CreateAliasRequest: request,
mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
}
method := "CreateAlias"
@ -4034,11 +3992,10 @@ func (node *Proxy) DropAlias(ctx context.Context, request *milvuspb.DropAliasReq
defer sp.End()
dat := &DropAliasTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
DropAliasRequest: request,
mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
ctx: ctx,
Condition: NewTaskCondition(ctx),
DropAliasRequest: request,
mixCoord: node.mixCoord,
}
method := "DropAlias"
@ -4098,11 +4055,10 @@ func (node *Proxy) AlterAlias(ctx context.Context, request *milvuspb.AlterAliasR
defer sp.End()
aat := &AlterAliasTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
AlterAliasRequest: request,
mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
ctx: ctx,
Condition: NewTaskCondition(ctx),
AlterAliasRequest: request,
mixCoord: node.mixCoord,
}
method := "AlterAlias"
@ -4174,11 +4130,11 @@ func (node *Proxy) FlushAll(ctx context.Context, request *milvuspb.FlushAllReque
defer sp.End()
ft := &flushAllTask{
ctx: ctx,
Condition: NewTaskCondition(ctx),
FlushAllRequest: request,
mixCoord: node.mixCoord,
replicateMsgStream: node.replicateMsgStream,
ctx: ctx,
Condition: NewTaskCondition(ctx),
FlushAllRequest: request,
mixCoord: node.mixCoord,
chMgr: node.chMgr,
}
method := "FlushAll"
@ -4191,15 +4147,7 @@ func (node *Proxy) FlushAll(ctx context.Context, request *milvuspb.FlushAllReque
log.Debug(rpcReceived(method))
var enqueuedTask task = ft
if streamingutil.IsStreamingServiceEnabled() {
enqueuedTask = &flushAllTaskbyStreamingService{
flushAllTask: ft,
chMgr: node.chMgr,
}
}
if err := node.sched.dcQueue.Enqueue(enqueuedTask); err != nil {
if err := node.sched.dcQueue.Enqueue(ft); err != nil {
log.Warn(rpcFailedToEnqueue(method), zap.Error(err))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.AbandonLabel, request.GetDbName(), "").Inc()
resp.Status = merr.Status(err)
@ -5198,9 +5146,6 @@ func (node *Proxy) CreateCredential(ctx context.Context, req *milvuspb.CreateCre
zap.Error(err))
return merr.Status(err), nil
}
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, err
}
@ -5273,9 +5218,6 @@ func (node *Proxy) UpdateCredential(ctx context.Context, req *milvuspb.UpdateCre
zap.Error(err))
return merr.Status(err), nil
}
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, err
}
@ -5306,9 +5248,6 @@ func (node *Proxy) DeleteCredential(ctx context.Context, req *milvuspb.DeleteCre
zap.Error(err))
return merr.Status(err), nil
}
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, err
}
@ -5372,9 +5311,6 @@ func (node *Proxy) CreateRole(ctx context.Context, req *milvuspb.CreateRoleReque
log.Warn("fail to create role", zap.Error(err))
return merr.Status(err), nil
}
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, nil
}
@ -5407,9 +5343,6 @@ func (node *Proxy) DropRole(ctx context.Context, req *milvuspb.DropRoleRequest)
zap.Error(err))
return merr.Status(err), nil
}
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, nil
}
@ -5439,9 +5372,6 @@ func (node *Proxy) OperateUserRole(ctx context.Context, req *milvuspb.OperateUse
log.Warn("fail to operate user role", zap.Error(err))
return merr.Status(err), nil
}
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, nil
}
@ -5624,9 +5554,6 @@ func (node *Proxy) OperatePrivilegeV2(ctx context.Context, req *milvuspb.Operate
}
}
}
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, nil
}
@ -5675,9 +5602,6 @@ func (node *Proxy) OperatePrivilege(ctx context.Context, req *milvuspb.OperatePr
}
}
}
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, nil
}
@ -5914,11 +5838,6 @@ func (node *Proxy) RenameCollection(ctx context.Context, req *milvuspb.RenameCol
log.Warn("failed to rename collection", zap.Error(err))
return merr.Status(err), err
}
if merr.Ok(resp) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return resp, nil
}
@ -6432,136 +6351,9 @@ func (node *Proxy) Connect(ctx context.Context, request *milvuspb.ConnectRequest
}
func (node *Proxy) ReplicateMessage(ctx context.Context, req *milvuspb.ReplicateMessageRequest) (*milvuspb.ReplicateMessageResponse, error) {
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
}
var err error
if req.GetChannelName() == "" {
log.Ctx(ctx).Warn("channel name is empty")
return &milvuspb.ReplicateMessageResponse{
Status: merr.Status(merr.WrapErrParameterInvalidMsg("invalid channel name for the replicate message request")),
}, nil
}
// get the latest position of the replicate msg channel
replicateMsgChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
if req.GetChannelName() == replicateMsgChannel {
msgID, err := msgstream.GetChannelLatestMsgID(ctx, node.factory, replicateMsgChannel)
if err != nil {
log.Ctx(ctx).Warn("failed to get the latest message id of the replicate msg channel", zap.Error(err))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
}
position := &msgpb.MsgPosition{
ChannelName: replicateMsgChannel,
MsgID: msgID,
}
positionBytes, err := proto.Marshal(position)
if err != nil {
log.Ctx(ctx).Warn("failed to marshal position", zap.Error(err))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
}
return &milvuspb.ReplicateMessageResponse{
Status: merr.Status(nil),
Position: base64.StdEncoding.EncodeToString(positionBytes),
}, nil
}
collectionReplicateEnable := paramtable.Get().CommonCfg.CollectionReplicateEnable.GetAsBool()
ttMsgEnabled := paramtable.Get().CommonCfg.TTMsgEnabled.GetAsBool()
// replicate message can be use in two ways, otherwise return error
// 1. collectionReplicateEnable is false and ttMsgEnabled is false, active/standby mode
// 2. collectionReplicateEnable is true and ttMsgEnabled is true, data migration mode
if (!collectionReplicateEnable && ttMsgEnabled) || (collectionReplicateEnable && !ttMsgEnabled) {
return &milvuspb.ReplicateMessageResponse{
Status: merr.Status(merr.ErrDenyReplicateMessage),
}, nil
}
msgPack := &msgstream.MsgPack{
BeginTs: req.BeginTs,
EndTs: req.EndTs,
Msgs: make([]msgstream.TsMsg, 0),
StartPositions: req.StartPositions,
EndPositions: req.EndPositions,
}
checkCollectionReplicateProperty := func(dbName, collectionName string) bool {
if !collectionReplicateEnable {
return true
}
replicateID, err := GetReplicateID(ctx, dbName, collectionName)
if err != nil {
log.Warn("get replicate id failed", zap.String("collectionName", collectionName), zap.Error(err))
return false
}
return replicateID != ""
}
// getTsMsgFromConsumerMsg
for i, msgBytes := range req.Msgs {
header := commonpb.MsgHeader{}
err = proto.Unmarshal(msgBytes, &header)
if err != nil {
log.Ctx(ctx).Warn("failed to unmarshal msg header", zap.Int("index", i), zap.Error(err))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
}
if header.GetBase() == nil {
log.Ctx(ctx).Warn("msg header base is nil", zap.Int("index", i))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.ErrInvalidMsgBytes)}, nil
}
tsMsg, err := node.replicateStreamManager.GetMsgDispatcher().Unmarshal(msgBytes, header.GetBase().GetMsgType())
if err != nil {
log.Ctx(ctx).Warn("failed to unmarshal msg", zap.Int("index", i), zap.Error(err))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.ErrInvalidMsgBytes)}, nil
}
switch realMsg := tsMsg.(type) {
case *msgstream.InsertMsg:
if !checkCollectionReplicateProperty(realMsg.GetDbName(), realMsg.GetCollectionName()) {
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.WrapErrCollectionReplicateMode("replicate"))}, nil
}
assignedSegmentInfos, err := node.segAssigner.GetSegmentID(realMsg.GetCollectionID(), realMsg.GetPartitionID(),
realMsg.GetShardName(), uint32(realMsg.NumRows), req.EndTs)
if err != nil {
log.Ctx(ctx).Warn("failed to get segment id", zap.Error(err))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
}
if len(assignedSegmentInfos) == 0 {
log.Ctx(ctx).Warn("no segment id assigned")
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.ErrNoAssignSegmentID)}, nil
}
for assignSegmentID := range assignedSegmentInfos {
realMsg.SegmentID = assignSegmentID
break
}
case *msgstream.DeleteMsg:
if !checkCollectionReplicateProperty(realMsg.GetDbName(), realMsg.GetCollectionName()) {
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(merr.WrapErrCollectionReplicateMode("replicate"))}, nil
}
}
msgPack.Msgs = append(msgPack.Msgs, tsMsg)
}
msgStream, err := node.replicateStreamManager.GetReplicateMsgStream(ctx, req.ChannelName)
if err != nil {
log.Ctx(ctx).Warn("failed to get msg stream from the replicate stream manager", zap.Error(err))
return &milvuspb.ReplicateMessageResponse{
Status: merr.Status(err),
}, nil
}
messageIDsMap, err := msgStream.Broadcast(ctx, msgPack)
if err != nil {
log.Ctx(ctx).Warn("failed to produce msg", zap.Error(err))
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(err)}, nil
}
var position string
if len(messageIDsMap[req.GetChannelName()]) == 0 {
log.Ctx(ctx).Warn("no message id returned")
} else {
messageIDs := messageIDsMap[req.GetChannelName()]
position = base64.StdEncoding.EncodeToString(messageIDs[len(messageIDs)-1].Serialize())
}
return &milvuspb.ReplicateMessageResponse{Status: merr.Status(nil), Position: position}, nil
return &milvuspb.ReplicateMessageResponse{
Status: merr.Status(merr.WrapErrServiceUnavailable("not supported in streaming mode")),
}, nil
}
func (node *Proxy) ListClientInfos(ctx context.Context, req *proxypb.ListClientInfosRequest) (*proxypb.ListClientInfosResponse, error) {
@ -6821,9 +6613,6 @@ func (node *Proxy) CreatePrivilegeGroup(ctx context.Context, req *milvuspb.Creat
log.Warn("fail to create privilege group", zap.Error(err))
return merr.Status(err), nil
}
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, nil
}
@ -6853,9 +6642,6 @@ func (node *Proxy) DropPrivilegeGroup(ctx context.Context, req *milvuspb.DropPri
log.Warn("fail to drop privilege group", zap.Error(err))
return merr.Status(err), nil
}
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, nil
}
@ -6919,9 +6705,6 @@ func (node *Proxy) OperatePrivilegeGroup(ctx context.Context, req *milvuspb.Oper
log.Warn("fail to operate privilege group", zap.Error(err))
return merr.Status(err), nil
}
if merr.Ok(result) {
SendReplicateMessagePack(ctx, node.replicateMsgStream, req)
}
return result, nil
}

View File

@ -18,12 +18,10 @@ package proxy
import (
"context"
"encoding/base64"
"math/rand"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/bytedance/mockey"
"github.com/cockroachdb/errors"
@ -31,38 +29,27 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
grpcmixcoordclient "github.com/milvus-io/milvus/internal/distributed/mixcoord/client"
"github.com/milvus-io/milvus/internal/distributed/streaming"
mhttp "github.com/milvus-io/milvus/internal/http"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/mocks/distributed/mock_streaming"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
mqcommon "github.com/milvus-io/milvus/pkg/v2/mq/common"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/proxypb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/ratelimitutil"
"github.com/milvus-io/milvus/pkg/v2/util/resource"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -254,7 +241,7 @@ func TestProxy_ResourceGroup(t *testing.T) {
qc.EXPECT().ShowLoadCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
tsoAllocatorIns := newMockTsoAllocator()
node.sched, err = newTaskScheduler(node.ctx, tsoAllocatorIns, node.factory)
node.sched, err = newTaskScheduler(node.ctx, tsoAllocatorIns)
assert.NoError(t, err)
node.sched.Start()
defer node.sched.Close()
@ -336,7 +323,7 @@ func TestProxy_InvalidResourceGroupName(t *testing.T) {
Status: merr.Success(),
}, nil).Maybe()
tsoAllocatorIns := newMockTsoAllocator()
node.sched, err = newTaskScheduler(node.ctx, tsoAllocatorIns, node.factory)
node.sched, err = newTaskScheduler(node.ctx, tsoAllocatorIns)
assert.NoError(t, err)
node.sched.Start()
defer node.sched.Close()
@ -393,11 +380,7 @@ func createTestProxy() *Proxy {
tso: newMockTimestampAllocatorInterface(),
}
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.replicateMsgStream, _ = node.factory.NewMsgStream(node.ctx)
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})
node.sched, _ = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
node.sched, _ = newTaskScheduler(ctx, node.tsoAllocator)
node.sched.Start()
return node
@ -418,10 +401,12 @@ func TestProxy_FlushAll_NoDatabase(t *testing.T) {
mockey.Mock(paramtable.Init).Return().Build()
mockey.Mock((*paramtable.ComponentParam).Save).Return().Build()
// Mock grpc mix coord client FlushAll method
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
mockey.Mock((*grpcmixcoordclient.Client).FlushAll).To(func(ctx context.Context, req *datapb.FlushAllRequest, opts ...grpc.CallOption) (*datapb.FlushAllResponse, error) {
return &datapb.FlushAllResponse{Status: successStatus}, nil
mockey.Mock((*grpcmixcoordclient.Client).ListDatabases).To(func(ctx context.Context, req *milvuspb.ListDatabasesRequest, opts ...grpc.CallOption) (*milvuspb.ListDatabasesResponse, error) {
return &milvuspb.ListDatabasesResponse{Status: successStatus}, nil
}).Build()
mockey.Mock((*grpcmixcoordclient.Client).ShowCollections).To(func(ctx context.Context, req *milvuspb.ShowCollectionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowCollectionsResponse, error) {
return &milvuspb.ShowCollectionsResponse{Status: successStatus}, nil
}).Build()
// Act: Execute test
@ -454,10 +439,13 @@ func TestProxy_FlushAll_WithDefaultDatabase(t *testing.T) {
mockey.Mock(paramtable.Init).Return().Build()
mockey.Mock((*paramtable.ComponentParam).Save).Return().Build()
// Mock grpc mix coord client FlushAll method for default database
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
mockey.Mock((*grpcmixcoordclient.Client).FlushAll).To(func(ctx context.Context, req *datapb.FlushAllRequest, opts ...grpc.CallOption) (*datapb.FlushAllResponse, error) {
return &datapb.FlushAllResponse{Status: successStatus}, nil
mockey.Mock((*grpcmixcoordclient.Client).ListDatabases).To(func(ctx context.Context, req *milvuspb.ListDatabasesRequest, opts ...grpc.CallOption) (*milvuspb.ListDatabasesResponse, error) {
return &milvuspb.ListDatabasesResponse{Status: successStatus, DbNames: []string{"default"}}, nil
}).Build()
// Mock grpc mix coord client FlushAll method for default database
mockey.Mock((*grpcmixcoordclient.Client).ShowCollections).To(func(ctx context.Context, req *milvuspb.ShowCollectionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowCollectionsResponse, error) {
return &milvuspb.ShowCollectionsResponse{Status: successStatus}, nil
}).Build()
// Act: Execute test
@ -490,9 +478,8 @@ func TestProxy_FlushAll_DatabaseNotExist(t *testing.T) {
mockey.Mock(paramtable.Init).Return().Build()
mockey.Mock((*paramtable.ComponentParam).Save).Return().Build()
// Mock grpc mix coord client FlushAll method for non-existent database
mockey.Mock((*grpcmixcoordclient.Client).FlushAll).To(func(ctx context.Context, req *datapb.FlushAllRequest, opts ...grpc.CallOption) (*datapb.FlushAllResponse, error) {
return &datapb.FlushAllResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_MetaFailed}}, nil
mockey.Mock((*grpcmixcoordclient.Client).ShowCollections).To(func(ctx context.Context, req *milvuspb.ShowCollectionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowCollectionsResponse, error) {
return &milvuspb.ShowCollectionsResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_MetaFailed}}, nil
}).Build()
// Act: Execute test
@ -889,18 +876,13 @@ func TestProxyCreateDatabase(t *testing.T) {
}
node.simpleLimiter = NewSimpleLimiter(0, 0)
node.UpdateStateCode(commonpb.StateCode_Healthy)
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator)
node.sched.ddQueue.setMaxTaskNum(10)
assert.NoError(t, err)
err = node.sched.Start()
assert.NoError(t, err)
defer node.sched.Close()
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
assert.NoError(t, err)
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})
t.Run("create database fail", func(t *testing.T) {
mixc := mocks.NewMockMixCoordClient(t)
mixc.On("CreateDatabase", mock.Anything, mock.Anything).
@ -949,18 +931,13 @@ func TestProxyDropDatabase(t *testing.T) {
}
node.simpleLimiter = NewSimpleLimiter(0, 0)
node.UpdateStateCode(commonpb.StateCode_Healthy)
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator)
node.sched.ddQueue.setMaxTaskNum(10)
assert.NoError(t, err)
err = node.sched.Start()
assert.NoError(t, err)
defer node.sched.Close()
rpcRequestChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
assert.NoError(t, err)
node.replicateMsgStream.AsProducer(ctx, []string{rpcRequestChannel})
t.Run("drop database fail", func(t *testing.T) {
mixc := mocks.NewMockMixCoordClient(t)
mixc.On("DropDatabase", mock.Anything, mock.Anything).
@ -1007,7 +984,7 @@ func TestProxyListDatabase(t *testing.T) {
}
node.simpleLimiter = NewSimpleLimiter(0, 0)
node.UpdateStateCode(commonpb.StateCode_Healthy)
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator)
node.sched.ddQueue.setMaxTaskNum(10)
assert.NoError(t, err)
err = node.sched.Start()
@ -1063,7 +1040,7 @@ func TestProxyAlterDatabase(t *testing.T) {
}
node.simpleLimiter = NewSimpleLimiter(0, 0)
node.UpdateStateCode(commonpb.StateCode_Healthy)
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator)
node.sched.ddQueue.setMaxTaskNum(10)
assert.NoError(t, err)
err = node.sched.Start()
@ -1116,7 +1093,7 @@ func TestProxyDescribeDatabase(t *testing.T) {
}
node.simpleLimiter = NewSimpleLimiter(0, 0)
node.UpdateStateCode(commonpb.StateCode_Healthy)
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator, node.factory)
node.sched, err = newTaskScheduler(ctx, node.tsoAllocator)
node.sched.ddQueue.setMaxTaskNum(10)
assert.NoError(t, err)
err = node.sched.Start()
@ -1347,7 +1324,7 @@ func TestProxy_Delete(t *testing.T) {
idAllocator, err := allocator.NewIDAllocator(ctx, rc, 0)
assert.NoError(t, err)
queue, err := newTaskScheduler(ctx, tsoAllocator, nil)
queue, err := newTaskScheduler(ctx, tsoAllocator)
assert.NoError(t, err)
node := &Proxy{chMgr: chMgr, rowIDAllocator: idAllocator, sched: queue}
@ -1358,287 +1335,7 @@ func TestProxy_Delete(t *testing.T) {
})
}
func TestProxy_ReplicateMessage(t *testing.T) {
paramtable.Init()
defer paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true")
t.Run("proxy unhealthy", func(t *testing.T) {
node := &Proxy{}
node.UpdateStateCode(commonpb.StateCode_Abnormal)
resp, err := node.ReplicateMessage(context.TODO(), nil)
assert.NoError(t, err)
assert.NotEqual(t, 0, resp.GetStatus().GetCode())
})
t.Run("not backup instance", func(t *testing.T) {
node := &Proxy{}
node.UpdateStateCode(commonpb.StateCode_Healthy)
resp, err := node.ReplicateMessage(context.TODO(), nil)
assert.NoError(t, err)
assert.NotEqual(t, 0, resp.GetStatus().GetCode())
})
t.Run("empty channel name", func(t *testing.T) {
node := &Proxy{}
node.UpdateStateCode(commonpb.StateCode_Healthy)
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false")
resp, err := node.ReplicateMessage(context.TODO(), nil)
assert.NoError(t, err)
assert.NotEqual(t, 0, resp.GetStatus().GetCode())
})
t.Run("fail to get msg stream", func(t *testing.T) {
factory := newMockMsgStreamFactory()
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return nil, errors.New("mock error: get msg stream")
}
resourceManager := resource.NewManager(time.Second, 2*time.Second, nil)
manager := NewReplicateStreamManager(context.Background(), factory, resourceManager)
node := &Proxy{
replicateStreamManager: manager,
}
node.UpdateStateCode(commonpb.StateCode_Healthy)
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false")
resp, err := node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{ChannelName: "unit_test_replicate_message"})
assert.NoError(t, err)
assert.NotEqual(t, 0, resp.GetStatus().GetCode())
})
t.Run("get latest position", func(t *testing.T) {
base64DecodeMsgPosition := func(position string) (*msgstream.MsgPosition, error) {
decodeBytes, err := base64.StdEncoding.DecodeString(position)
if err != nil {
log.Warn("fail to decode the position", zap.Error(err))
return nil, err
}
msgPosition := &msgstream.MsgPosition{}
err = proto.Unmarshal(decodeBytes, msgPosition)
if err != nil {
log.Warn("fail to unmarshal the position", zap.Error(err))
return nil, err
}
return msgPosition, nil
}
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false")
defer paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true")
factory := dependency.NewMockFactory(t)
stream := msgstream.NewMockMsgStream(t)
mockMsgID := mqcommon.NewMockMessageID(t)
factory.EXPECT().NewMsgStream(mock.Anything).Return(stream, nil).Once()
mockMsgID.EXPECT().Serialize().Return([]byte("mock")).Once()
stream.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once()
stream.EXPECT().GetLatestMsgID(mock.Anything).Return(mockMsgID, nil).Once()
stream.EXPECT().Close().Return()
node := &Proxy{
factory: factory,
}
node.UpdateStateCode(commonpb.StateCode_Healthy)
resp, err := node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{
ChannelName: Params.CommonCfg.ReplicateMsgChannel.GetValue(),
})
assert.NoError(t, err)
assert.EqualValues(t, 0, resp.GetStatus().GetCode())
{
p, err := base64DecodeMsgPosition(resp.GetPosition())
assert.NoError(t, err)
assert.Equal(t, []byte("mock"), p.MsgID)
}
factory.EXPECT().NewMsgStream(mock.Anything).Return(nil, errors.New("mock")).Once()
resp, err = node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{
ChannelName: Params.CommonCfg.ReplicateMsgChannel.GetValue(),
})
assert.NoError(t, err)
assert.NotEqualValues(t, 0, resp.GetStatus().GetCode())
})
t.Run("invalid msg pack", func(t *testing.T) {
node := &Proxy{
replicateStreamManager: NewReplicateStreamManager(context.Background(), nil, nil),
}
node.UpdateStateCode(commonpb.StateCode_Healthy)
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false")
{
resp, err := node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{
ChannelName: "unit_test_replicate_message",
Msgs: [][]byte{{1, 2, 3}},
})
assert.NoError(t, err)
assert.NotEqual(t, 0, resp.GetStatus().GetCode())
}
{
timeTickMsg := &msgstream.TimeTickMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: 1,
EndTimestamp: 10,
HashValues: []uint32{0},
},
TimeTickMsg: &msgpb.TimeTickMsg{},
}
msgBytes, _ := timeTickMsg.Marshal(timeTickMsg)
resp, err := node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{
ChannelName: "unit_test_replicate_message",
Msgs: [][]byte{msgBytes.([]byte)},
})
assert.NoError(t, err)
log.Info("resp", zap.Any("resp", resp))
assert.NotEqual(t, 0, resp.GetStatus().GetCode())
}
{
timeTickMsg := &msgstream.TimeTickMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: 1,
EndTimestamp: 10,
HashValues: []uint32{0},
},
TimeTickMsg: &msgpb.TimeTickMsg{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType(-1)),
commonpbutil.WithTimeStamp(10),
commonpbutil.WithSourceID(-1),
),
},
}
msgBytes, _ := timeTickMsg.Marshal(timeTickMsg)
resp, err := node.ReplicateMessage(context.TODO(), &milvuspb.ReplicateMessageRequest{
ChannelName: "unit_test_replicate_message",
Msgs: [][]byte{msgBytes.([]byte)},
})
assert.NoError(t, err)
log.Info("resp", zap.Any("resp", resp))
assert.NotEqual(t, 0, resp.GetStatus().GetCode())
}
})
t.Run("success", func(t *testing.T) {
paramtable.Init()
factory := newMockMsgStreamFactory()
msgStreamObj := msgstream.NewMockMsgStream(t)
msgStreamObj.EXPECT().SetRepackFunc(mock.Anything).Return()
msgStreamObj.EXPECT().AsProducer(mock.Anything, mock.Anything).Return()
msgStreamObj.EXPECT().ForceEnableProduce(mock.Anything).Return()
msgStreamObj.EXPECT().Close().Return()
mockMsgID1 := mqcommon.NewMockMessageID(t)
mockMsgID2 := mqcommon.NewMockMessageID(t)
mockMsgID2.EXPECT().Serialize().Return([]byte("mock message id 2"))
broadcastMock := msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{
"unit_test_replicate_message": {mockMsgID1, mockMsgID2},
}, nil)
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return msgStreamObj, nil
}
resourceManager := resource.NewManager(time.Second, 2*time.Second, nil)
manager := NewReplicateStreamManager(context.Background(), factory, resourceManager)
ctx := context.Background()
dataCoord := &mockDataCoord{}
dataCoord.expireTime = Timestamp(1000)
segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick1)
assert.NoError(t, err)
segAllocator.Start()
node := &Proxy{
replicateStreamManager: manager,
segAssigner: segAllocator,
}
node.UpdateStateCode(commonpb.StateCode_Healthy)
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false")
insertMsg := &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: 4,
EndTimestamp: 10,
HashValues: []uint32{0},
MsgPosition: &msgstream.MsgPosition{
ChannelName: "unit_test_replicate_message",
MsgID: []byte("mock message id 2"),
},
},
InsertRequest: &msgpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
MsgID: 10001,
Timestamp: 10,
SourceID: -1,
},
ShardName: "unit_test_replicate_message_v1",
DbName: "default",
CollectionName: "foo_collection",
PartitionName: "_default",
DbID: 1,
CollectionID: 11,
PartitionID: 22,
SegmentID: 33,
Timestamps: []uint64{10},
RowIDs: []int64{66},
NumRows: 1,
},
}
msgBytes, _ := insertMsg.Marshal(insertMsg)
replicateRequest := &milvuspb.ReplicateMessageRequest{
ChannelName: "unit_test_replicate_message",
BeginTs: 1,
EndTs: 10,
Msgs: [][]byte{msgBytes.([]byte)},
StartPositions: []*msgpb.MsgPosition{
{ChannelName: "unit_test_replicate_message", MsgID: []byte("mock message id 1")},
},
EndPositions: []*msgpb.MsgPosition{
{ChannelName: "unit_test_replicate_message", MsgID: []byte("mock message id 2")},
},
}
resp, err := node.ReplicateMessage(context.TODO(), replicateRequest)
assert.NoError(t, err)
assert.EqualValues(t, 0, resp.GetStatus().GetCode())
assert.Equal(t, base64.StdEncoding.EncodeToString([]byte("mock message id 2")), resp.GetPosition())
res := resourceManager.Delete(ReplicateMsgStreamTyp, replicateRequest.GetChannelName())
assert.NotNil(t, res)
time.Sleep(2 * time.Second)
{
broadcastMock.Unset()
broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(nil, errors.New("mock error: broadcast"))
resp, err := node.ReplicateMessage(context.TODO(), replicateRequest)
assert.NoError(t, err)
assert.NotEqualValues(t, 0, resp.GetStatus().GetCode())
resourceManager.Delete(ReplicateMsgStreamTyp, replicateRequest.GetChannelName())
time.Sleep(2 * time.Second)
}
{
broadcastMock.Unset()
broadcastMock = msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{
"unit_test_replicate_message": {},
}, nil)
resp, err := node.ReplicateMessage(context.TODO(), replicateRequest)
assert.NoError(t, err)
assert.EqualValues(t, 0, resp.GetStatus().GetCode())
assert.Empty(t, resp.GetPosition())
resourceManager.Delete(ReplicateMsgStreamTyp, replicateRequest.GetChannelName())
time.Sleep(2 * time.Second)
broadcastMock.Unset()
}
})
}
func TestProxy_ImportV2(t *testing.T) {
wal := mock_streaming.NewMockWALAccesser(t)
b := mock_streaming.NewMockBroadcast(t)
wal.EXPECT().Broadcast().Return(b).Maybe()
b.EXPECT().Append(mock.Anything, mock.Anything).Return(&types.BroadcastAppendResult{}, nil).Maybe()
streaming.SetWALForTest(wal)
defer streaming.RecoverWALForTest()
ctx := context.Background()
mockErr := errors.New("mock error")
@ -1661,7 +1358,7 @@ func TestProxy_ImportV2(t *testing.T) {
node.tsoAllocator = &timestampAllocator{
tso: newMockTimestampAllocatorInterface(),
}
scheduler, err := newTaskScheduler(ctx, node.tsoAllocator, factory)
scheduler, err := newTaskScheduler(ctx, node.tsoAllocator)
assert.NoError(t, err)
node.sched = scheduler
err = node.sched.Start()
@ -1916,333 +1613,6 @@ func TestRegisterRestRouter(t *testing.T) {
}
}
func TestReplicateMessageForCollectionMode(t *testing.T) {
paramtable.Init()
ctx := context.Background()
insertMsg := &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: 10,
EndTimestamp: 10,
HashValues: []uint32{0},
MsgPosition: &msgstream.MsgPosition{
ChannelName: "foo",
MsgID: []byte("mock message id 2"),
},
},
InsertRequest: &msgpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
MsgID: 10001,
Timestamp: 10,
SourceID: -1,
},
ShardName: "foo_v1",
DbName: "default",
CollectionName: "foo_collection",
PartitionName: "_default",
DbID: 1,
CollectionID: 11,
PartitionID: 22,
SegmentID: 33,
Timestamps: []uint64{10},
RowIDs: []int64{66},
NumRows: 1,
},
}
insertMsgBytes, _ := insertMsg.Marshal(insertMsg)
deleteMsg := &msgstream.DeleteMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: 20,
EndTimestamp: 20,
HashValues: []uint32{0},
MsgPosition: &msgstream.MsgPosition{
ChannelName: "foo",
MsgID: []byte("mock message id 2"),
},
},
DeleteRequest: &msgpb.DeleteRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Delete,
MsgID: 10002,
Timestamp: 20,
SourceID: -1,
},
ShardName: "foo_v1",
DbName: "default",
CollectionName: "foo_collection",
PartitionName: "_default",
DbID: 1,
CollectionID: 11,
PartitionID: 22,
},
}
deleteMsgBytes, _ := deleteMsg.Marshal(deleteMsg)
cache := globalMetaCache
defer func() { globalMetaCache = cache }()
t.Run("replicate message in the replicate collection mode", func(t *testing.T) {
defer func() {
paramtable.Get().Reset(paramtable.Get().CommonCfg.TTMsgEnabled.Key)
paramtable.Get().Reset(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key)
}()
{
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true")
paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "false")
p := &Proxy{}
p.UpdateStateCode(commonpb.StateCode_Healthy)
r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "foo",
})
assert.NoError(t, err)
assert.Error(t, merr.Error(r.Status))
}
{
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "false")
paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "true")
p := &Proxy{}
p.UpdateStateCode(commonpb.StateCode_Healthy)
r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "foo",
})
assert.NoError(t, err)
assert.Error(t, merr.Error(r.Status))
}
})
t.Run("replicate message for the replicate collection mode", func(t *testing.T) {
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true")
paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "true")
defer func() {
paramtable.Get().Reset(paramtable.Get().CommonCfg.TTMsgEnabled.Key)
paramtable.Get().Reset(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key)
}()
mockCache := NewMockCache(t)
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil).Twice()
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{}, nil).Twice()
globalMetaCache = mockCache
{
p := &Proxy{
replicateStreamManager: NewReplicateStreamManager(context.Background(), nil, nil),
}
p.UpdateStateCode(commonpb.StateCode_Healthy)
r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "foo",
Msgs: [][]byte{insertMsgBytes.([]byte)},
})
assert.NoError(t, err)
assert.EqualValues(t, r.GetStatus().GetCode(), merr.Code(merr.ErrCollectionReplicateMode))
}
{
p := &Proxy{
replicateStreamManager: NewReplicateStreamManager(context.Background(), nil, nil),
}
p.UpdateStateCode(commonpb.StateCode_Healthy)
r, err := p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "foo",
Msgs: [][]byte{deleteMsgBytes.([]byte)},
})
assert.NoError(t, err)
assert.EqualValues(t, r.GetStatus().GetCode(), merr.Code(merr.ErrCollectionReplicateMode))
}
})
}
func TestAlterCollectionReplicateProperty(t *testing.T) {
paramtable.Init()
paramtable.Get().Save(paramtable.Get().CommonCfg.TTMsgEnabled.Key, "true")
paramtable.Get().Save(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key, "true")
defer func() {
paramtable.Get().Reset(paramtable.Get().CommonCfg.TTMsgEnabled.Key)
paramtable.Get().Reset(paramtable.Get().CommonCfg.CollectionReplicateEnable.Key)
}()
cache := globalMetaCache
defer func() { globalMetaCache = cache }()
mockCache := NewMockCache(t)
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{
replicateID: "local-milvus",
}, nil).Maybe()
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(1, nil).Maybe()
mockCache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(&schemaInfo{}, nil)
globalMetaCache = mockCache
factory := newMockMsgStreamFactory()
msgStreamObj := msgstream.NewMockMsgStream(t)
msgStreamObj.EXPECT().SetRepackFunc(mock.Anything).Return().Maybe()
msgStreamObj.EXPECT().AsProducer(mock.Anything, mock.Anything).Return().Maybe()
msgStreamObj.EXPECT().ForceEnableProduce(mock.Anything).Return().Maybe()
msgStreamObj.EXPECT().Close().Return().Maybe()
mockMsgID1 := mqcommon.NewMockMessageID(t)
mockMsgID2 := mqcommon.NewMockMessageID(t)
mockMsgID2.EXPECT().Serialize().Return([]byte("mock message id 2")).Maybe()
msgStreamObj.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(map[string][]mqcommon.MessageID{
"alter_property": {mockMsgID1, mockMsgID2},
}, nil).Maybe()
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return msgStreamObj, nil
}
resourceManager := resource.NewManager(time.Second, 2*time.Second, nil)
manager := NewReplicateStreamManager(context.Background(), factory, resourceManager)
ctx := context.Background()
var startTt uint64 = 10
startTime := time.Now()
dataCoord := &mockDataCoord{}
dataCoord.expireTime = Timestamp(1000)
segAllocator, err := newSegIDAssigner(ctx, dataCoord, func() Timestamp {
return Timestamp(time.Since(startTime).Seconds()) + startTt
})
assert.NoError(t, err)
segAllocator.Start()
mockMixcoord := mocks.NewMockMixCoordClient(t)
mockMixcoord.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request *rootcoordpb.AllocTimestampRequest, option ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) {
return &rootcoordpb.AllocTimestampResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Timestamp: Timestamp(time.Since(startTime).Seconds()) + startTt,
}, nil
})
mockMixcoord.EXPECT().AlterCollection(mock.Anything, mock.Anything).Return(&commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}, nil)
p := &Proxy{
ctx: ctx,
replicateStreamManager: manager,
segAssigner: segAllocator,
mixCoord: mockMixcoord,
}
tsoAllocatorIns := newMockTsoAllocator()
p.sched, err = newTaskScheduler(p.ctx, tsoAllocatorIns, p.factory)
assert.NoError(t, err)
p.sched.Start()
defer p.sched.Close()
p.UpdateStateCode(commonpb.StateCode_Healthy)
getInsertMsgBytes := func(channel string, ts uint64) []byte {
insertMsg := &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: ts,
EndTimestamp: ts,
HashValues: []uint32{0},
MsgPosition: &msgstream.MsgPosition{
ChannelName: channel,
MsgID: []byte("mock message id 2"),
},
},
InsertRequest: &msgpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
MsgID: 10001,
Timestamp: ts,
SourceID: -1,
},
ShardName: channel + "_v1",
DbName: "default",
CollectionName: "foo_collection",
PartitionName: "_default",
DbID: 1,
CollectionID: 11,
PartitionID: 22,
SegmentID: 33,
Timestamps: []uint64{ts},
RowIDs: []int64{66},
NumRows: 1,
},
}
insertMsgBytes, _ := insertMsg.Marshal(insertMsg)
return insertMsgBytes.([]byte)
}
getDeleteMsgBytes := func(channel string, ts uint64) []byte {
deleteMsg := &msgstream.DeleteMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: ts,
EndTimestamp: ts,
HashValues: []uint32{0},
MsgPosition: &msgstream.MsgPosition{
ChannelName: "foo",
MsgID: []byte("mock message id 2"),
},
},
DeleteRequest: &msgpb.DeleteRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Delete,
MsgID: 10002,
Timestamp: ts,
SourceID: -1,
},
ShardName: channel + "_v1",
DbName: "default",
CollectionName: "foo_collection",
PartitionName: "_default",
DbID: 1,
CollectionID: 11,
PartitionID: 22,
},
}
deleteMsgBytes, _ := deleteMsg.Marshal(deleteMsg)
return deleteMsgBytes.([]byte)
}
go func() {
// replicate message
var replicateResp *milvuspb.ReplicateMessageResponse
var err error
replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "alter_property_1",
Msgs: [][]byte{getInsertMsgBytes("alter_property_1", startTt+5)},
})
assert.NoError(t, err)
assert.True(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason)
replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "alter_property_2",
Msgs: [][]byte{getDeleteMsgBytes("alter_property_2", startTt+5)},
})
assert.NoError(t, err)
assert.True(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason)
time.Sleep(time.Second)
replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "alter_property_1",
Msgs: [][]byte{getInsertMsgBytes("alter_property_1", startTt+10)},
})
assert.NoError(t, err)
assert.False(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason)
replicateResp, err = p.ReplicateMessage(ctx, &milvuspb.ReplicateMessageRequest{
ChannelName: "alter_property_2",
Msgs: [][]byte{getInsertMsgBytes("alter_property_2", startTt+10)},
})
assert.NoError(t, err)
assert.False(t, merr.Ok(replicateResp.Status), replicateResp.Status.Reason)
}()
time.Sleep(200 * time.Millisecond)
// alter collection property
statusResp, err := p.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
DbName: "default",
CollectionName: "foo_collection",
Properties: []*commonpb.KeyValuePair{
{
Key: "replicate.endTS",
Value: "1",
},
},
})
assert.NoError(t, err)
assert.True(t, merr.Ok(statusResp))
}
func TestRunAnalyzer(t *testing.T) {
paramtable.Init()
ctx := context.Background()
@ -2254,7 +1624,7 @@ func TestRunAnalyzer(t *testing.T) {
p := &Proxy{}
tsoAllocatorIns := newMockTsoAllocator()
sched, err := newTaskScheduler(ctx, tsoAllocatorIns, p.factory)
sched, err := newTaskScheduler(ctx, tsoAllocatorIns)
require.NoError(t, err)
sched.Start()
defer sched.Close()

View File

@ -2,12 +2,7 @@
package proxy
import (
context "context"
msgstream "github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
mock "github.com/stretchr/testify/mock"
)
import mock "github.com/stretchr/testify/mock"
// MockChannelsMgr is an autogenerated mock type for the channelsMgr type
type MockChannelsMgr struct {
@ -80,65 +75,6 @@ func (_c *MockChannelsMgr_getChannels_Call) RunAndReturn(run func(int64) ([]stri
return _c
}
// getOrCreateDmlStream provides a mock function with given fields: ctx, collectionID
func (_m *MockChannelsMgr) getOrCreateDmlStream(ctx context.Context, collectionID int64) (msgstream.MsgStream, error) {
ret := _m.Called(ctx, collectionID)
if len(ret) == 0 {
panic("no return value specified for getOrCreateDmlStream")
}
var r0 msgstream.MsgStream
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, int64) (msgstream.MsgStream, error)); ok {
return rf(ctx, collectionID)
}
if rf, ok := ret.Get(0).(func(context.Context, int64) msgstream.MsgStream); ok {
r0 = rf(ctx, collectionID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(msgstream.MsgStream)
}
}
if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok {
r1 = rf(ctx, collectionID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockChannelsMgr_getOrCreateDmlStream_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'getOrCreateDmlStream'
type MockChannelsMgr_getOrCreateDmlStream_Call struct {
*mock.Call
}
// getOrCreateDmlStream is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
func (_e *MockChannelsMgr_Expecter) getOrCreateDmlStream(ctx interface{}, collectionID interface{}) *MockChannelsMgr_getOrCreateDmlStream_Call {
return &MockChannelsMgr_getOrCreateDmlStream_Call{Call: _e.mock.On("getOrCreateDmlStream", ctx, collectionID)}
}
func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) Run(run func(ctx context.Context, collectionID int64)) *MockChannelsMgr_getOrCreateDmlStream_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64))
})
return _c
}
func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) Return(_a0 msgstream.MsgStream, _a1 error) *MockChannelsMgr_getOrCreateDmlStream_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockChannelsMgr_getOrCreateDmlStream_Call) RunAndReturn(run func(context.Context, int64) (msgstream.MsgStream, error)) *MockChannelsMgr_getOrCreateDmlStream_Call {
_c.Call.Return(run)
return _c
}
// getVChannels provides a mock function with given fields: collectionID
func (_m *MockChannelsMgr) getVChannels(collectionID int64) ([]string, error) {
ret := _m.Called(collectionID)
@ -197,38 +133,6 @@ func (_c *MockChannelsMgr_getVChannels_Call) RunAndReturn(run func(int64) ([]str
return _c
}
// removeAllDMLStream provides a mock function with no fields
func (_m *MockChannelsMgr) removeAllDMLStream() {
_m.Called()
}
// MockChannelsMgr_removeAllDMLStream_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'removeAllDMLStream'
type MockChannelsMgr_removeAllDMLStream_Call struct {
*mock.Call
}
// removeAllDMLStream is a helper method to define mock.On call
func (_e *MockChannelsMgr_Expecter) removeAllDMLStream() *MockChannelsMgr_removeAllDMLStream_Call {
return &MockChannelsMgr_removeAllDMLStream_Call{Call: _e.mock.On("removeAllDMLStream")}
}
func (_c *MockChannelsMgr_removeAllDMLStream_Call) Run(run func()) *MockChannelsMgr_removeAllDMLStream_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockChannelsMgr_removeAllDMLStream_Call) Return() *MockChannelsMgr_removeAllDMLStream_Call {
_c.Call.Return()
return _c
}
func (_c *MockChannelsMgr_removeAllDMLStream_Call) RunAndReturn(run func()) *MockChannelsMgr_removeAllDMLStream_Call {
_c.Run(run)
return _c
}
// removeDMLStream provides a mock function with given fields: collectionID
func (_m *MockChannelsMgr) removeDMLStream(collectionID int64) {
_m.Called(collectionID)

View File

@ -18,23 +18,12 @@ package proxy
import (
"context"
"strconv"
"time"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/retry"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -103,197 +92,3 @@ func genInsertMsgsByPartition(ctx context.Context,
return repackedMsgs, nil
}
func repackInsertDataByPartition(ctx context.Context,
partitionName string,
rowOffsets []int,
channelName string,
insertMsg *msgstream.InsertMsg,
segIDAssigner *segIDAssigner,
) ([]msgstream.TsMsg, error) {
res := make([]msgstream.TsMsg, 0)
maxTs := Timestamp(0)
for _, offset := range rowOffsets {
ts := insertMsg.Timestamps[offset]
if maxTs < ts {
maxTs = ts
}
}
partitionID, err := globalMetaCache.GetPartitionID(ctx, insertMsg.GetDbName(), insertMsg.CollectionName, partitionName)
if err != nil {
return nil, err
}
beforeAssign := time.Now()
assignedSegmentInfos, err := segIDAssigner.GetSegmentID(insertMsg.CollectionID, partitionID, channelName, uint32(len(rowOffsets)), maxTs)
metrics.ProxyAssignSegmentIDLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(time.Since(beforeAssign).Milliseconds()))
if err != nil {
log.Ctx(ctx).Error("allocate segmentID for insert data failed",
zap.String("collectionName", insertMsg.CollectionName),
zap.String("channelName", channelName),
zap.Int("allocate count", len(rowOffsets)),
zap.Error(err))
return nil, err
}
startPos := 0
for segmentID, count := range assignedSegmentInfos {
subRowOffsets := rowOffsets[startPos : startPos+int(count)]
msgs, err := genInsertMsgsByPartition(ctx, segmentID, partitionID, partitionName, subRowOffsets, channelName, insertMsg)
if err != nil {
log.Ctx(ctx).Warn("repack insert data to insert msgs failed",
zap.String("collectionName", insertMsg.CollectionName),
zap.Int64("partitionID", partitionID),
zap.Error(err))
return nil, err
}
res = append(res, msgs...)
startPos += int(count)
}
return res, nil
}
func setMsgID(ctx context.Context,
msgs []msgstream.TsMsg,
idAllocator *allocator.IDAllocator,
) error {
var idBegin int64
var err error
err = retry.Do(ctx, func() error {
idBegin, _, err = idAllocator.Alloc(uint32(len(msgs)))
return err
})
if err != nil {
log.Ctx(ctx).Error("failed to allocate msg id", zap.Error(err))
return err
}
for i, msg := range msgs {
msg.SetID(idBegin + UniqueID(i))
}
return nil
}
func repackInsertData(ctx context.Context,
channelNames []string,
insertMsg *msgstream.InsertMsg,
result *milvuspb.MutationResult,
idAllocator *allocator.IDAllocator,
segIDAssigner *segIDAssigner,
) (*msgstream.MsgPack, error) {
msgPack := &msgstream.MsgPack{
BeginTs: insertMsg.BeginTs(),
EndTs: insertMsg.EndTs(),
}
channel2RowOffsets := assignChannelsByPK(result.IDs, channelNames, insertMsg)
for channel, rowOffsets := range channel2RowOffsets {
partitionName := insertMsg.PartitionName
msgs, err := repackInsertDataByPartition(ctx, partitionName, rowOffsets, channel, insertMsg, segIDAssigner)
if err != nil {
log.Ctx(ctx).Warn("repack insert data to msg pack failed",
zap.String("collectionName", insertMsg.CollectionName),
zap.String("partition name", partitionName),
zap.Error(err))
return nil, err
}
msgPack.Msgs = append(msgPack.Msgs, msgs...)
}
err := setMsgID(ctx, msgPack.Msgs, idAllocator)
if err != nil {
log.Ctx(ctx).Error("failed to set msgID when repack insert data",
zap.String("collectionName", insertMsg.CollectionName),
zap.String("partition name", insertMsg.PartitionName),
zap.Error(err))
return nil, err
}
return msgPack, nil
}
func repackInsertDataWithPartitionKey(ctx context.Context,
channelNames []string,
partitionKeys *schemapb.FieldData,
insertMsg *msgstream.InsertMsg,
result *milvuspb.MutationResult,
idAllocator *allocator.IDAllocator,
segIDAssigner *segIDAssigner,
) (*msgstream.MsgPack, error) {
msgPack := &msgstream.MsgPack{
BeginTs: insertMsg.BeginTs(),
EndTs: insertMsg.EndTs(),
}
channel2RowOffsets := assignChannelsByPK(result.IDs, channelNames, insertMsg)
partitionNames, err := getDefaultPartitionsInPartitionKeyMode(ctx, insertMsg.GetDbName(), insertMsg.CollectionName)
if err != nil {
log.Ctx(ctx).Warn("get default partition names failed in partition key mode",
zap.String("collectionName", insertMsg.CollectionName),
zap.Error(err))
return nil, err
}
hashValues, err := typeutil.HashKey2Partitions(partitionKeys, partitionNames)
if err != nil {
log.Ctx(ctx).Warn("has partition keys to partitions failed",
zap.String("collectionName", insertMsg.CollectionName),
zap.Error(err))
return nil, err
}
for channel, rowOffsets := range channel2RowOffsets {
partition2RowOffsets := make(map[string][]int)
for _, idx := range rowOffsets {
partitionName := partitionNames[hashValues[idx]]
if _, ok := partition2RowOffsets[partitionName]; !ok {
partition2RowOffsets[partitionName] = []int{}
}
partition2RowOffsets[partitionName] = append(partition2RowOffsets[partitionName], idx)
}
errGroup, _ := errgroup.WithContext(ctx)
partition2Msgs := typeutil.NewConcurrentMap[string, []msgstream.TsMsg]()
for partitionName, offsets := range partition2RowOffsets {
partitionName := partitionName
offsets := offsets
errGroup.Go(func() error {
msgs, err := repackInsertDataByPartition(ctx, partitionName, offsets, channel, insertMsg, segIDAssigner)
if err != nil {
return err
}
partition2Msgs.Insert(partitionName, msgs)
return nil
})
}
err = errGroup.Wait()
if err != nil {
log.Ctx(ctx).Warn("repack insert data into insert msg pack failed",
zap.String("collectionName", insertMsg.CollectionName),
zap.String("channelName", channel),
zap.Error(err))
return nil, err
}
partition2Msgs.Range(func(name string, msgs []msgstream.TsMsg) bool {
msgPack.Msgs = append(msgPack.Msgs, msgs...)
return true
})
}
err = setMsgID(ctx, msgPack.Msgs, idAllocator)
if err != nil {
log.Ctx(ctx).Error("failed to set msgID when repack insert data",
zap.String("collectionName", insertMsg.CollectionName),
zap.Error(err))
return nil, err
}
return msgPack, nil
}

View File

@ -21,7 +21,6 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -49,12 +48,6 @@ func TestRepackInsertData(t *testing.T) {
defer mix.Close()
cache := NewMockCache(t)
cache.On("GetPartitionID",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(int64(1), nil)
globalMetaCache = cache
idAllocator, err := allocator.NewIDAllocator(ctx, mix, paramtable.GetNodeID())
@ -113,33 +106,6 @@ func TestRepackInsertData(t *testing.T) {
for index := range insertMsg.RowIDs {
insertMsg.RowIDs[index] = int64(index)
}
ids, err := parsePrimaryFieldData2IDs(fieldData)
assert.NoError(t, err)
result := &milvuspb.MutationResult{
IDs: ids,
}
t.Run("assign segmentID failed", func(t *testing.T) {
fakeSegAllocator, err := newSegIDAssigner(ctx, &mockDataCoord2{expireTime: Timestamp(2500)}, getLastTick1)
assert.NoError(t, err)
_ = fakeSegAllocator.Start()
defer fakeSegAllocator.Close()
_, err = repackInsertData(ctx, []string{"test_dml_channel"}, insertMsg,
result, idAllocator, fakeSegAllocator)
assert.Error(t, err)
})
segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
assert.NoError(t, err)
_ = segAllocator.Start()
defer segAllocator.Close()
t.Run("repack insert data success", func(t *testing.T) {
_, err = repackInsertData(ctx, []string{"test_dml_channel"}, insertMsg, result, idAllocator, segAllocator)
assert.NoError(t, err)
})
}
func TestRepackInsertDataWithPartitionKey(t *testing.T) {
@ -161,11 +127,6 @@ func TestRepackInsertDataWithPartitionKey(t *testing.T) {
_ = idAllocator.Start()
defer idAllocator.Close()
segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
assert.NoError(t, err)
_ = segAllocator.Start()
defer segAllocator.Close()
fieldName2Types := map[string]schemapb.DataType{
testInt64Field: schemapb.DataType_Int64,
testVarCharField: schemapb.DataType_VarChar,
@ -221,17 +182,4 @@ func TestRepackInsertDataWithPartitionKey(t *testing.T) {
for index := range insertMsg.RowIDs {
insertMsg.RowIDs[index] = int64(index)
}
ids, err := parsePrimaryFieldData2IDs(fieldNameToDatas[testInt64Field])
assert.NoError(t, err)
result := &milvuspb.MutationResult{
IDs: ids,
}
t.Run("repack insert data success", func(t *testing.T) {
partitionKeys := generateFieldData(schemapb.DataType_VarChar, testVarCharField, nb)
_, err = repackInsertDataWithPartitionKey(ctx, []string{"test_dml_channel"}, partitionKeys,
insertMsg, result, idAllocator, segAllocator)
assert.NoError(t, err)
})
}

View File

@ -21,14 +21,12 @@ import (
"fmt"
"math/rand"
"os"
"strconv"
"sync"
"time"
"github.com/cockroachdb/errors"
"github.com/google/uuid"
"github.com/hashicorp/golang-lru/v2/expirable"
clientv3 "go.etcd.io/etcd/client/v3"
"go.uber.org/atomic"
"go.uber.org/zap"
@ -40,19 +38,15 @@ import (
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/expr"
"github.com/milvus-io/milvus/pkg/v2/util/logutil"
"github.com/milvus-io/milvus/pkg/v2/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/ratelimitutil"
"github.com/milvus-io/milvus/pkg/v2/util/resource"
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -87,7 +81,6 @@ type Proxy struct {
stateCode atomic.Int32
etcdCli *clientv3.Client
address string
mixCoord types.MixCoordClient
@ -95,23 +88,16 @@ type Proxy struct {
chMgr channelsMgr
replicateMsgStream msgstream.MsgStream
sched *taskScheduler
chTicker channelsTimeTicker
rowIDAllocator *allocator.IDAllocator
tsoAllocator *timestampAllocator
segAssigner *segIDAssigner
metricsCacheManager *metricsinfo.MetricsCacheManager
session *sessionutil.Session
shardMgr shardClientMgr
factory dependency.Factory
searchResultCh chan *internalpb.SearchResults
// Add callback functions at different stages
@ -122,8 +108,7 @@ type Proxy struct {
lbPolicy LBPolicy
// resource manager
resourceManager resource.Manager
replicateStreamManager *ReplicateStreamManager
resourceManager resource.Manager
// materialized view
enableMaterializedView bool
@ -135,7 +120,7 @@ type Proxy struct {
}
// NewProxy returns a Proxy struct.
func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
func NewProxy(ctx context.Context, _ dependency.Factory) (*Proxy, error) {
rand.Seed(time.Now().UnixNano())
ctx1, cancel := context.WithCancel(ctx)
n := 1024 // better to be configurable
@ -143,18 +128,15 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
lbPolicy := NewLBPolicyImpl(mgr)
lbPolicy.Start(ctx)
resourceManager := resource.NewManager(10*time.Second, 20*time.Second, make(map[string]time.Duration))
replicateStreamManager := NewReplicateStreamManager(ctx, factory, resourceManager)
node := &Proxy{
ctx: ctx1,
cancel: cancel,
factory: factory,
searchResultCh: make(chan *internalpb.SearchResults, n),
shardMgr: mgr,
simpleLimiter: NewSimpleLimiter(Params.QuotaConfig.AllocWaitInterval.GetAsDuration(time.Millisecond), Params.QuotaConfig.AllocRetryTimes.GetAsUint()),
lbPolicy: lbPolicy,
resourceManager: resourceManager,
replicateStreamManager: replicateStreamManager,
slowQueries: expirable.NewLRU[Timestamp, *metricsinfo.SlowQuery](20, nil, time.Minute*15),
ctx: ctx1,
cancel: cancel,
searchResultCh: make(chan *internalpb.SearchResults, n),
shardMgr: mgr,
simpleLimiter: NewSimpleLimiter(Params.QuotaConfig.AllocWaitInterval.GetAsDuration(time.Millisecond), Params.QuotaConfig.AllocRetryTimes.GetAsUint()),
lbPolicy: lbPolicy,
resourceManager: resourceManager,
slowQueries: expirable.NewLRU[Timestamp, *metricsinfo.SlowQuery](20, nil, time.Minute*15),
}
node.UpdateStateCode(commonpb.StateCode_Abnormal)
expr.Register("proxy", node)
@ -223,10 +205,6 @@ func (node *Proxy) Init() error {
}
log.Info("init session for Proxy done")
node.factory.Init(Params)
log.Debug("init access log for Proxy done")
err := node.initRateCollector()
if err != nil {
return err
@ -253,44 +231,18 @@ func (node *Proxy) Init() error {
node.tsoAllocator = tsoAllocator
log.Debug("create timestamp allocator done", zap.String("role", typeutil.ProxyRole), zap.Int64("ProxyID", paramtable.GetNodeID()))
segAssigner, err := newSegIDAssigner(node.ctx, node.mixCoord, node.lastTick)
if err != nil {
log.Warn("failed to create segment id assigner",
zap.String("role", typeutil.ProxyRole), zap.Int64("ProxyID", paramtable.GetNodeID()),
zap.Error(err))
return err
}
node.segAssigner = segAssigner
node.segAssigner.PeerID = paramtable.GetNodeID()
log.Debug("create segment id assigner done", zap.String("role", typeutil.ProxyRole), zap.Int64("ProxyID", paramtable.GetNodeID()))
dmlChannelsFunc := getDmlChannelsFunc(node.ctx, node.mixCoord)
chMgr := newChannelsMgrImpl(dmlChannelsFunc, defaultInsertRepackFunc, node.factory)
chMgr := newChannelsMgrImpl(dmlChannelsFunc, defaultInsertRepackFunc)
node.chMgr = chMgr
log.Debug("create channels manager done", zap.String("role", typeutil.ProxyRole))
replicateMsgChannel := Params.CommonCfg.ReplicateMsgChannel.GetValue()
node.replicateMsgStream, err = node.factory.NewMsgStream(node.ctx)
if err != nil {
log.Warn("failed to create replicate msg stream",
zap.String("role", typeutil.ProxyRole), zap.Int64("ProxyID", paramtable.GetNodeID()),
zap.Error(err))
return err
}
node.replicateMsgStream.ForceEnableProduce(true)
node.replicateMsgStream.AsProducer(node.ctx, []string{replicateMsgChannel})
node.sched, err = newTaskScheduler(node.ctx, node.tsoAllocator, node.factory)
node.sched, err = newTaskScheduler(node.ctx, node.tsoAllocator)
if err != nil {
log.Warn("failed to create task scheduler", zap.String("role", typeutil.ProxyRole), zap.Error(err))
return err
}
log.Debug("create task scheduler done", zap.String("role", typeutil.ProxyRole))
syncTimeTickInterval := Params.ProxyCfg.TimeTickInterval.GetAsDuration(time.Millisecond) / 2
node.chTicker = newChannelsTimeTicker(node.ctx, Params.ProxyCfg.TimeTickInterval.GetAsDuration(time.Millisecond)/2, []string{}, node.sched.getPChanStatistics, tsoAllocator)
log.Debug("create channels time ticker done", zap.String("role", typeutil.ProxyRole), zap.Duration("syncTimeTickInterval", syncTimeTickInterval))
node.enableComplexDeleteLimit = Params.QuotaConfig.ComplexDeleteLimitEnable.GetAsBool()
node.metricsCacheManager = metricsinfo.NewMetricsCacheManager()
log.Debug("create metrics cache manager done", zap.String("role", typeutil.ProxyRole))
@ -314,90 +266,6 @@ func (node *Proxy) Init() error {
return nil
}
// sendChannelsTimeTickLoop starts a goroutine that synchronizes the time tick information.
func (node *Proxy) sendChannelsTimeTickLoop() {
log := log.Ctx(node.ctx)
node.wg.Add(1)
go func() {
defer node.wg.Done()
ticker := time.NewTicker(Params.ProxyCfg.TimeTickInterval.GetAsDuration(time.Millisecond))
defer ticker.Stop()
for {
select {
case <-node.ctx.Done():
log.Info("send channels time tick loop exit")
return
case <-ticker.C:
if !Params.CommonCfg.TTMsgEnabled.GetAsBool() {
continue
}
stats, ts, err := node.chTicker.getMinTsStatistics()
if err != nil {
log.Warn("sendChannelsTimeTickLoop.getMinTsStatistics", zap.Error(err))
continue
}
if ts == 0 {
log.Warn("sendChannelsTimeTickLoop.getMinTsStatistics default timestamp equal 0")
continue
}
channels := make([]pChan, 0, len(stats))
tss := make([]Timestamp, 0, len(stats))
maxTs := ts
for channel, ts := range stats {
channels = append(channels, channel)
tss = append(tss, ts)
if ts > maxTs {
maxTs = ts
}
}
req := &internalpb.ChannelTimeTickMsg{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_TimeTick),
commonpbutil.WithSourceID(node.session.ServerID),
),
ChannelNames: channels,
Timestamps: tss,
DefaultTimestamp: maxTs,
}
func() {
// we should pay more attention to the max lag.
minTs := maxTs
minTsOfChannel := "default"
// find the min ts and the related channel.
for channel, ts := range stats {
if ts < minTs {
minTs = ts
minTsOfChannel = channel
}
}
sub := tsoutil.SubByNow(minTs)
metrics.ProxySyncTimeTickLag.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), minTsOfChannel).Set(float64(sub))
}()
status, err := node.mixCoord.UpdateChannelTimeTick(node.ctx, req)
if err != nil {
log.Warn("sendChannelsTimeTickLoop.UpdateChannelTimeTick", zap.Error(err))
continue
}
if status.GetErrorCode() != 0 {
log.Warn("sendChannelsTimeTickLoop.UpdateChannelTimeTick",
zap.Any("ErrorCode", status.ErrorCode),
zap.Any("Reason", status.Reason))
continue
}
}
}
}()
}
// Start starts a proxy node.
func (node *Proxy) Start() error {
log := log.Ctx(node.ctx)
@ -417,22 +285,6 @@ func (node *Proxy) Start() error {
}
log.Debug("start id allocator done", zap.String("role", typeutil.ProxyRole))
if !streamingutil.IsStreamingServiceEnabled() {
if err := node.segAssigner.Start(); err != nil {
log.Warn("failed to start segment id assigner", zap.String("role", typeutil.ProxyRole), zap.Error(err))
return err
}
log.Debug("start segment id assigner done", zap.String("role", typeutil.ProxyRole))
if err := node.chTicker.start(); err != nil {
log.Warn("failed to start channels time ticker", zap.String("role", typeutil.ProxyRole), zap.Error(err))
return err
}
log.Debug("start channels time ticker done", zap.String("role", typeutil.ProxyRole))
node.sendChannelsTimeTickLoop()
}
// Start callbacks
for _, cb := range node.startCallbacks {
cb()
@ -465,21 +317,6 @@ func (node *Proxy) Stop() error {
log.Info("close scheduler", zap.String("role", typeutil.ProxyRole))
}
if !streamingutil.IsStreamingServiceEnabled() {
if node.segAssigner != nil {
node.segAssigner.Close()
log.Info("close segment id assigner", zap.String("role", typeutil.ProxyRole))
}
if node.chTicker != nil {
err := node.chTicker.close()
if err != nil {
return err
}
log.Info("close channels time ticker", zap.String("role", typeutil.ProxyRole))
}
}
for _, cb := range node.closeCallbacks {
cb()
}
@ -492,10 +329,6 @@ func (node *Proxy) Stop() error {
node.shardMgr.Close()
}
if node.chMgr != nil {
node.chMgr.removeAllDMLStream()
}
if node.lbPolicy != nil {
node.lbPolicy.Close()
}
@ -519,11 +352,6 @@ func (node *Proxy) AddStartCallback(callbacks ...func()) {
node.startCallbacks = append(node.startCallbacks, callbacks...)
}
// lastTick returns the last write timestamp of all pchans in this Proxy.
func (node *Proxy) lastTick() Timestamp {
return node.chTicker.getMinTick()
}
// AddCloseCallback adds a callback in the Close phase.
func (node *Proxy) AddCloseCallback(callbacks ...func()) {
node.closeCallbacks = append(node.closeCallbacks, callbacks...)
@ -537,11 +365,6 @@ func (node *Proxy) GetAddress() string {
return node.address
}
// SetEtcdClient sets etcd client for proxy.
func (node *Proxy) SetEtcdClient(client *clientv3.Client) {
node.etcdCli = client
}
// SetMixCoordClient sets MixCoord client for proxy.
func (node *Proxy) SetMixCoordClient(cli types.MixCoordClient) {
node.mixCoord = cli

View File

@ -23,13 +23,11 @@ import (
"fmt"
"math/rand"
"net"
"os"
"strconv"
"sync"
"testing"
"time"
"github.com/blang/semver/v4"
"github.com/cockroachdb/errors"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/prometheus/client_golang/prometheus"
@ -50,12 +48,13 @@ import (
mixc "github.com/milvus-io/milvus/internal/distributed/mixcoord/client"
grpcquerynode "github.com/milvus-io/milvus/internal/distributed/querynode"
"github.com/milvus-io/milvus/internal/distributed/streaming"
grpcstreamingnode "github.com/milvus-io/milvus/internal/distributed/streamingnode"
"github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/mocks/distributed/mock_streaming"
"github.com/milvus-io/milvus/internal/util/componentutil"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/internal/util/testutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
@ -64,7 +63,6 @@ import (
"github.com/milvus-io/milvus/pkg/v2/proto/proxypb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/tracer"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/crypto"
@ -89,7 +87,6 @@ func init() {
Registry = prometheus.NewRegistry()
Registry.MustRegister(prometheus.NewProcessCollector(prometheus.ProcessCollectorOpts{}))
Registry.MustRegister(prometheus.NewGoCollector())
common.Version = semver.MustParse("2.5.9")
}
func runMixCoord(ctx context.Context, localMsg bool) *grpcmixcoord.Server {
@ -119,6 +116,33 @@ func runMixCoord(ctx context.Context, localMsg bool) *grpcmixcoord.Server {
return rc
}
func runStreamingNode(ctx context.Context, localMsg bool, alias string) *grpcstreamingnode.Server {
var sn *grpcstreamingnode.Server
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
factory := dependency.MockDefaultFactory(localMsg, Params)
var err error
sn, err = grpcstreamingnode.NewServer(ctx, factory)
if err != nil {
panic(err)
}
if err = sn.Prepare(); err != nil {
panic(err)
}
err = sn.Run()
if err != nil {
panic(err)
}
}()
wg.Wait()
metrics.RegisterStreamingNode(Registry)
return sn
}
func runQueryNode(ctx context.Context, localMsg bool, alias string) *grpcquerynode.Server {
var qn *grpcquerynode.Server
var wg sync.WaitGroup
@ -276,17 +300,9 @@ func TestProxy(t *testing.T) {
paramtable.Init()
params := paramtable.Get()
testutil.ResetEnvironment()
wal := mock_streaming.NewMockWALAccesser(t)
b := mock_streaming.NewMockBroadcast(t)
wal.EXPECT().Broadcast().Return(b).Maybe()
b.EXPECT().Append(mock.Anything, mock.Anything).Return(&types.BroadcastAppendResult{}, nil).Maybe()
local := mock_streaming.NewMockLocal(t)
local.EXPECT().GetLatestMVCCTimestampIfLocal(mock.Anything, mock.Anything).Return(0, nil).Maybe()
local.EXPECT().GetMetricsIfLocal(mock.Anything).Return(&types.StreamingNodeMetrics{}, nil).Maybe()
wal.EXPECT().Local().Return(local).Maybe()
streaming.SetWALForTest(wal)
defer streaming.RecoverWALForTest()
paramtable.SetLocalComponentEnabled(typeutil.StreamingNodeRole)
streamingutil.SetStreamingServiceEnabled()
defer streamingutil.UnsetStreamingServiceEnabled()
params.RootCoordGrpcServerCfg.IP = "localhost"
params.QueryCoordGrpcServerCfg.IP = "localhost"
@ -295,15 +311,14 @@ func TestProxy(t *testing.T) {
params.QueryNodeGrpcServerCfg.IP = "localhost"
params.DataNodeGrpcServerCfg.IP = "localhost"
params.StreamingNodeGrpcServerCfg.IP = "localhost"
path := "/tmp/milvus/rocksmq" + funcutil.GenRandomStr()
t.Setenv("ROCKSMQ_PATH", path)
defer os.RemoveAll(path)
params.Save(params.MQCfg.Type.Key, "pulsar")
params.CommonCfg.EnableStorageV2.SwapTempValue("false")
defer params.CommonCfg.EnableStorageV2.SwapTempValue("")
ctx, cancel := context.WithCancel(context.Background())
ctx = GetContext(ctx, "root:123456")
localMsg := true
factory := dependency.MockDefaultFactory(localMsg, Params)
factory := dependency.NewDefaultFactory(false)
alias := "TestProxy"
log.Info("Initialize parameter table of Proxy")
@ -314,11 +329,16 @@ func TestProxy(t *testing.T) {
dn := runDataNode(ctx, localMsg, alias)
log.Info("running DataNode ...")
sn := runStreamingNode(ctx, localMsg, alias)
log.Info("running StreamingNode ...")
qn := runQueryNode(ctx, localMsg, alias)
log.Info("running QueryNode ...")
time.Sleep(10 * time.Millisecond)
streaming.Init()
proxy, err := NewProxy(ctx, factory)
assert.NoError(t, err)
assert.NotNil(t, proxy)
@ -331,9 +351,11 @@ func TestProxy(t *testing.T) {
Params.EtcdCfg.EtcdTLSKey.GetValue(),
Params.EtcdCfg.EtcdTLSCACert.GetValue(),
Params.EtcdCfg.EtcdTLSMinVersion.GetValue())
if err != nil {
panic(err)
}
defer etcdcli.Close()
assert.NoError(t, err)
proxy.SetEtcdClient(etcdcli)
testServer := newProxyTestServer(proxy)
wg.Add(1)
@ -372,7 +394,7 @@ func TestProxy(t *testing.T) {
assert.NoError(t, err)
log.Info("Register proxy done")
defer func() {
a := []any{mix, qn, dn, proxy}
a := []any{mix, qn, dn, sn, proxy}
fmt.Println(len(a))
// HINT: the order of stopping service refers to the `roles.go` file
log.Info("start to stop the services")
@ -1106,6 +1128,10 @@ func TestProxy(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
segmentIDs = resp.CollSegIDs[collectionName].Data
// TODO: Here's a Bug, because a growing segment may cannot be seen right away by mixcoord,
// it can only be seen by streamingnode right away, so we need to check the flush state at streamingnode but not here.
// use timetick for GetFlushState in-future but not segment list.
time.Sleep(5 * time.Second)
log.Info("flush collection", zap.Int64s("segments to be flushed", segmentIDs))
f := func() bool {
@ -4792,11 +4818,7 @@ func TestProxy_Import(t *testing.T) {
cache := globalMetaCache
defer func() { globalMetaCache = cache }()
wal := mock_streaming.NewMockWALAccesser(t)
b := mock_streaming.NewMockBroadcast(t)
wal.EXPECT().Broadcast().Return(b).Maybe()
streaming.SetWALForTest(wal)
defer streaming.RecoverWALForTest()
streaming.SetupNoopWALForTest()
t.Run("Import failed", func(t *testing.T) {
proxy := &Proxy{}
@ -4825,7 +4847,6 @@ func TestProxy_Import(t *testing.T) {
chMgr.EXPECT().getVChannels(mock.Anything).Return([]string{"foo"}, nil)
proxy.chMgr = chMgr
factory := dependency.NewDefaultFactory(true)
rc := mocks.NewMockRootCoordClient(t)
rc.EXPECT().AllocID(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocIDResponse{
ID: rand.Int63(),
@ -4839,19 +4860,12 @@ func TestProxy_Import(t *testing.T) {
proxy.tsoAllocator = &timestampAllocator{
tso: newMockTimestampAllocatorInterface(),
}
scheduler, err := newTaskScheduler(ctx, proxy.tsoAllocator, factory)
scheduler, err := newTaskScheduler(ctx, proxy.tsoAllocator)
assert.NoError(t, err)
proxy.sched = scheduler
err = proxy.sched.Start()
assert.NoError(t, err)
wal := mock_streaming.NewMockWALAccesser(t)
b := mock_streaming.NewMockBroadcast(t)
wal.EXPECT().Broadcast().Return(b)
b.EXPECT().Append(mock.Anything, mock.Anything).Return(&types.BroadcastAppendResult{}, nil)
streaming.SetWALForTest(wal)
defer streaming.RecoverWALForTest()
req := &milvuspb.ImportRequest{
CollectionName: "dummy",
Files: []string{"a.json"},

View File

@ -1,72 +0,0 @@
package proxy
import (
"context"
"time"
"go.uber.org/zap"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/resource"
)
const (
ReplicateMsgStreamTyp = "replicate_msg_stream"
ReplicateMsgStreamExpireTime = 30 * time.Second
)
type ReplicateStreamManager struct {
ctx context.Context
factory msgstream.Factory
dispatcher msgstream.UnmarshalDispatcher
resourceManager resource.Manager
}
func NewReplicateStreamManager(ctx context.Context, factory msgstream.Factory, resourceManager resource.Manager) *ReplicateStreamManager {
manager := &ReplicateStreamManager{
ctx: ctx,
factory: factory,
dispatcher: (&msgstream.ProtoUDFactory{}).NewUnmarshalDispatcher(),
resourceManager: resourceManager,
}
return manager
}
func (m *ReplicateStreamManager) newMsgStreamResource(ctx context.Context, channel string) resource.NewResourceFunc {
return func() (resource.Resource, error) {
msgStream, err := m.factory.NewMsgStream(ctx)
if err != nil {
log.Ctx(m.ctx).Warn("failed to create msg stream", zap.String("channel", channel), zap.Error(err))
return nil, err
}
msgStream.SetRepackFunc(replicatePackFunc)
msgStream.AsProducer(ctx, []string{channel})
msgStream.ForceEnableProduce(true)
res := resource.NewSimpleResource(msgStream, ReplicateMsgStreamTyp, channel, ReplicateMsgStreamExpireTime, func() {
msgStream.Close()
})
return res, nil
}
}
func (m *ReplicateStreamManager) GetReplicateMsgStream(ctx context.Context, channel string) (msgstream.MsgStream, error) {
ctxLog := log.Ctx(ctx).With(zap.String("proxy_channel", channel))
res, err := m.resourceManager.Get(ReplicateMsgStreamTyp, channel, m.newMsgStreamResource(ctx, channel))
if err != nil {
ctxLog.Warn("failed to get replicate msg stream", zap.String("channel", channel), zap.Error(err))
return nil, err
}
if obj, ok := res.Get().(msgstream.MsgStream); ok && obj != nil {
return obj, nil
}
ctxLog.Warn("invalid resource object", zap.Any("obj", res.Get()))
return nil, merr.ErrInvalidStreamObj
}
func (m *ReplicateStreamManager) GetMsgDispatcher() msgstream.UnmarshalDispatcher {
return m.dispatcher
}

View File

@ -1,85 +0,0 @@
package proxy
import (
"context"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/resource"
)
func TestReplicateManager(t *testing.T) {
factory := newMockMsgStreamFactory()
resourceManager := resource.NewManager(time.Second, 2*time.Second, nil)
manager := NewReplicateStreamManager(context.Background(), factory, resourceManager)
{
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return nil, errors.New("mock msgstream fail")
}
_, err := manager.GetReplicateMsgStream(context.Background(), "test")
assert.Error(t, err)
}
{
mockMsgStream := newMockMsgStream()
i := 0
mockMsgStream.setRepack = func(repackFunc msgstream.RepackFunc) {
i++
}
mockMsgStream.asProducer = func(producers []string) {
i++
}
mockMsgStream.forceEnableProduce = func(b bool) {
i++
}
mockMsgStream.close = func() {
i++
}
factory.f = func(ctx context.Context) (msgstream.MsgStream, error) {
return mockMsgStream, nil
}
_, err := manager.GetReplicateMsgStream(context.Background(), "test")
assert.NoError(t, err)
assert.Equal(t, 3, i)
time.Sleep(time.Second)
_, err = manager.GetReplicateMsgStream(context.Background(), "test")
assert.NoError(t, err)
assert.Equal(t, 3, i)
res := resourceManager.Delete(ReplicateMsgStreamTyp, "test")
assert.NotNil(t, res)
assert.Eventually(t, func() bool {
return resourceManager.Delete(ReplicateMsgStreamTyp, "test") == nil
}, time.Second*4, time.Millisecond*500)
_, err = manager.GetReplicateMsgStream(context.Background(), "test")
assert.NoError(t, err)
assert.Equal(t, 7, i)
}
{
res := resourceManager.Delete(ReplicateMsgStreamTyp, "test")
assert.NotNil(t, res)
assert.Eventually(t, func() bool {
return resourceManager.Delete(ReplicateMsgStreamTyp, "test") == nil
}, time.Second*4, time.Millisecond*500)
res, err := resourceManager.Get(ReplicateMsgStreamTyp, "test", func() (resource.Resource, error) {
return resource.NewResource(resource.WithObj("str")), nil
})
assert.NoError(t, err)
assert.Equal(t, "str", res.Get())
_, err = manager.GetReplicateMsgStream(context.Background(), "test")
assert.ErrorIs(t, err, merr.ErrInvalidStreamObj)
}
{
assert.NotNil(t, manager.GetMsgDispatcher())
}
}

View File

@ -1,401 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"container/list"
"context"
"fmt"
"strconv"
"time"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
const (
segCountPerRPC = 20000
)
// DataCoord is a narrowed interface of DataCoordinator which only provide AssignSegmentID method
type DataCoord interface {
AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error)
}
type segRequest struct {
allocator.BaseRequest
count uint32
collID UniqueID
partitionID UniqueID
segInfo map[UniqueID]uint32
channelName string
timestamp Timestamp
}
type segInfo struct {
segID UniqueID
count uint32
expireTime Timestamp
}
type assignInfo struct {
collID UniqueID
partitionID UniqueID
channelName string
segInfos *list.List
lastInsertTime time.Time
}
func (info *segInfo) IsExpired(ts Timestamp) bool {
return ts > info.expireTime || info.count <= 0
}
func (info *segInfo) Capacity(ts Timestamp) uint32 {
if info.IsExpired(ts) {
return 0
}
return info.count
}
func (info *segInfo) Assign(ts Timestamp, count uint32) uint32 {
if info.IsExpired(ts) {
log.Ctx(context.TODO()).Debug("segInfo Assign IsExpired", zap.Uint64("ts", ts),
zap.Uint32("count", count))
return 0
}
ret := uint32(0)
if info.count >= count {
info.count -= count
ret = count
} else {
ret = info.count
info.count = 0
}
return ret
}
func (info *assignInfo) RemoveExpired(ts Timestamp) {
var next *list.Element
for e := info.segInfos.Front(); e != nil; e = next {
next = e.Next()
segInfo, ok := e.Value.(*segInfo)
if !ok {
log.Warn("can not cast to segInfo")
continue
}
if segInfo.IsExpired(ts) {
info.segInfos.Remove(e)
}
}
}
func (info *assignInfo) Capacity(ts Timestamp) uint32 {
ret := uint32(0)
for e := info.segInfos.Front(); e != nil; e = e.Next() {
segInfo := e.Value.(*segInfo)
ret += segInfo.Capacity(ts)
}
return ret
}
func (info *assignInfo) Assign(ts Timestamp, count uint32) (map[UniqueID]uint32, error) {
capacity := info.Capacity(ts)
if capacity < count {
errMsg := fmt.Sprintf("AssignSegment Failed: capacity:%d is less than count:%d", capacity, count)
return nil, errors.New(errMsg)
}
result := make(map[UniqueID]uint32)
for e := info.segInfos.Front(); e != nil && count != 0; e = e.Next() {
segInfo := e.Value.(*segInfo)
cur := segInfo.Assign(ts, count)
count -= cur
if cur > 0 {
result[segInfo.segID] += cur
}
}
return result, nil
}
type segIDAssigner struct {
allocator.CachedAllocator
assignInfos map[UniqueID]*list.List // collectionID -> *list.List
segReqs []*datapb.SegmentIDRequest
getTickFunc func() Timestamp
PeerID UniqueID
dataCoord DataCoord
countPerRPC uint32
}
// newSegIDAssigner creates a new segIDAssigner
func newSegIDAssigner(ctx context.Context, dataCoord DataCoord, getTickFunc func() Timestamp) (*segIDAssigner, error) {
ctx1, cancel := context.WithCancel(ctx)
sa := &segIDAssigner{
CachedAllocator: allocator.CachedAllocator{
Ctx: ctx1,
CancelFunc: cancel,
Role: "SegmentIDAllocator",
},
countPerRPC: segCountPerRPC,
dataCoord: dataCoord,
assignInfos: make(map[UniqueID]*list.List),
getTickFunc: getTickFunc,
}
sa.TChan = &allocator.Ticker{
UpdateInterval: time.Second,
}
sa.CachedAllocator.SyncFunc = sa.syncSegments
sa.CachedAllocator.ProcessFunc = sa.processFunc
sa.CachedAllocator.CheckSyncFunc = sa.checkSyncFunc
sa.CachedAllocator.PickCanDoFunc = sa.pickCanDoFunc
sa.Init()
return sa, nil
}
func (sa *segIDAssigner) collectExpired() {
ts := sa.getTickFunc()
var next *list.Element
for _, info := range sa.assignInfos {
for e := info.Front(); e != nil; e = next {
next = e.Next()
assign := e.Value.(*assignInfo)
assign.RemoveExpired(ts)
if assign.Capacity(ts) == 0 {
info.Remove(e)
}
}
}
}
func (sa *segIDAssigner) pickCanDoFunc() {
if sa.ToDoReqs == nil {
return
}
records := make(map[UniqueID]map[UniqueID]map[string]uint32)
var newTodoReqs []allocator.Request
for _, req := range sa.ToDoReqs {
segRequest := req.(*segRequest)
collID := segRequest.collID
partitionID := segRequest.partitionID
channelName := segRequest.channelName
if _, ok := records[collID]; !ok {
records[collID] = make(map[UniqueID]map[string]uint32)
}
if _, ok := records[collID][partitionID]; !ok {
records[collID][partitionID] = make(map[string]uint32)
}
if _, ok := records[collID][partitionID][channelName]; !ok {
records[collID][partitionID][channelName] = 0
}
records[collID][partitionID][channelName] += segRequest.count
assign, err := sa.getAssign(segRequest.collID, segRequest.partitionID, segRequest.channelName)
if err != nil || assign.Capacity(segRequest.timestamp) < records[collID][partitionID][channelName] {
sa.segReqs = append(sa.segReqs, &datapb.SegmentIDRequest{
ChannelName: channelName,
Count: segRequest.count,
CollectionID: collID,
PartitionID: partitionID,
})
newTodoReqs = append(newTodoReqs, req)
} else {
sa.CanDoReqs = append(sa.CanDoReqs, req)
}
}
log.Ctx(context.TODO()).Debug("Proxy segIDAssigner pickCanDoFunc", zap.Any("records", records),
zap.Int("len(newTodoReqs)", len(newTodoReqs)),
zap.Int("len(CanDoReqs)", len(sa.CanDoReqs)))
sa.ToDoReqs = newTodoReqs
}
func (sa *segIDAssigner) getAssign(collID UniqueID, partitionID UniqueID, channelName string) (*assignInfo, error) {
assignInfos, ok := sa.assignInfos[collID]
if !ok {
return nil, fmt.Errorf("can not find collection %d", collID)
}
for e := assignInfos.Front(); e != nil; e = e.Next() {
info := e.Value.(*assignInfo)
if info.partitionID != partitionID || info.channelName != channelName {
continue
}
return info, nil
}
return nil, fmt.Errorf("can not find assign info with collID %d, partitionID %d, channelName %s",
collID, partitionID, channelName)
}
func (sa *segIDAssigner) checkSyncFunc(timeout bool) bool {
sa.collectExpired()
return timeout || len(sa.segReqs) != 0
}
func (sa *segIDAssigner) checkSegReqEqual(req1, req2 *datapb.SegmentIDRequest) bool {
if req1 == nil || req2 == nil {
return false
}
if req1 == req2 {
return true
}
return req1.CollectionID == req2.CollectionID && req1.PartitionID == req2.PartitionID && req1.ChannelName == req2.ChannelName
}
func (sa *segIDAssigner) reduceSegReqs() {
log.Ctx(context.TODO()).Debug("Proxy segIDAssigner reduceSegReqs", zap.Int("len(segReqs)", len(sa.segReqs)))
if len(sa.segReqs) == 0 {
return
}
beforeCnt := uint32(0)
var newSegReqs []*datapb.SegmentIDRequest
for _, req1 := range sa.segReqs {
if req1.Count == 0 {
log.Ctx(context.TODO()).Debug("Proxy segIDAssigner reduceSegReqs hit perCount == 0")
req1.Count = sa.countPerRPC
}
beforeCnt += req1.Count
var req2 *datapb.SegmentIDRequest
for _, req3 := range newSegReqs {
if sa.checkSegReqEqual(req1, req3) {
req2 = req3
break
}
}
if req2 == nil { // not found
newSegReqs = append(newSegReqs, req1)
} else {
req2.Count += req1.Count
}
}
afterCnt := uint32(0)
for _, req := range newSegReqs {
afterCnt += req.Count
}
sa.segReqs = newSegReqs
log.Ctx(context.TODO()).Debug("Proxy segIDAssigner reduceSegReqs after reduce", zap.Int("len(segReqs)", len(sa.segReqs)),
zap.Uint32("BeforeCnt", beforeCnt),
zap.Uint32("AfterCnt", afterCnt))
}
func (sa *segIDAssigner) syncSegments() (bool, error) {
if len(sa.segReqs) == 0 {
return true, nil
}
sa.reduceSegReqs()
req := &datapb.AssignSegmentIDRequest{
NodeID: sa.PeerID,
PeerRole: typeutil.ProxyRole,
SegmentIDRequests: sa.segReqs,
}
metrics.ProxySyncSegmentRequestLength.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(len(sa.segReqs)))
sa.segReqs = nil
log.Ctx(context.TODO()).Debug("syncSegments call dataCoord.AssignSegmentID", zap.Stringer("request", req))
resp, err := sa.dataCoord.AssignSegmentID(context.Background(), req)
if err != nil {
return false, fmt.Errorf("syncSegmentID Failed:%w", err)
}
if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
return false, fmt.Errorf("syncSegmentID Failed:%s", resp.GetStatus().GetReason())
}
var errMsg string
now := time.Now()
success := true
for _, segAssign := range resp.SegIDAssignments {
if segAssign.Status.GetErrorCode() != commonpb.ErrorCode_Success {
log.Ctx(context.TODO()).Warn("proxy", zap.String("SyncSegment Error", segAssign.GetStatus().GetReason()))
errMsg += segAssign.GetStatus().GetReason()
errMsg += "\n"
success = false
continue
}
assign, err := sa.getAssign(segAssign.CollectionID, segAssign.PartitionID, segAssign.ChannelName)
segInfo2 := &segInfo{
segID: segAssign.SegID,
count: segAssign.Count,
expireTime: segAssign.ExpireTime,
}
if err != nil {
colInfos, ok := sa.assignInfos[segAssign.CollectionID]
if !ok {
colInfos = list.New()
}
segInfos := list.New()
segInfos.PushBack(segInfo2)
assign = &assignInfo{
collID: segAssign.CollectionID,
partitionID: segAssign.PartitionID,
channelName: segAssign.ChannelName,
segInfos: segInfos,
}
colInfos.PushBack(assign)
sa.assignInfos[segAssign.CollectionID] = colInfos
} else {
assign.segInfos.PushBack(segInfo2)
}
assign.lastInsertTime = now
}
if !success {
return false, errors.New(errMsg)
}
return success, nil
}
func (sa *segIDAssigner) processFunc(req allocator.Request) error {
segRequest := req.(*segRequest)
assign, err := sa.getAssign(segRequest.collID, segRequest.partitionID, segRequest.channelName)
if err != nil {
return err
}
result, err2 := assign.Assign(segRequest.timestamp, segRequest.count)
segRequest.segInfo = result
return err2
}
func (sa *segIDAssigner) GetSegmentID(collID UniqueID, partitionID UniqueID, channelName string, count uint32, ts Timestamp) (map[UniqueID]uint32, error) {
req := &segRequest{
BaseRequest: allocator.BaseRequest{Done: make(chan error), Valid: false},
collID: collID,
partitionID: partitionID,
channelName: channelName,
count: count,
timestamp: ts,
}
sa.Reqs <- req
if err := req.Wait(); err != nil {
return nil, fmt.Errorf("getSegmentID failed: %s", err)
}
return req.segInfo, nil
}

View File

@ -1,307 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"context"
"math/rand"
"sync"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
type mockDataCoord struct {
expireTime Timestamp
}
func (mockD *mockDataCoord) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) {
assigns := make([]*datapb.SegmentIDAssignment, 0, len(req.SegmentIDRequests))
maxPerCnt := 100
for _, r := range req.SegmentIDRequests {
totalCnt := uint32(0)
for totalCnt != r.Count {
cnt := uint32(rand.Intn(maxPerCnt))
if totalCnt+cnt > r.Count {
cnt = r.Count - totalCnt
}
totalCnt += cnt
result := &datapb.SegmentIDAssignment{
SegID: 1,
ChannelName: r.ChannelName,
Count: cnt,
CollectionID: r.CollectionID,
PartitionID: r.PartitionID,
ExpireTime: mockD.expireTime,
Status: merr.Success(),
}
assigns = append(assigns, result)
}
}
return &datapb.AssignSegmentIDResponse{
Status: merr.Success(),
SegIDAssignments: assigns,
}, nil
}
type mockDataCoord2 struct {
expireTime Timestamp
}
func (mockD *mockDataCoord2) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) {
return &datapb.AssignSegmentIDResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "Just For Test",
},
}, nil
}
func getLastTick1() Timestamp {
return 1000
}
func TestSegmentAllocator1(t *testing.T) {
ctx := context.Background()
dataCoord := &mockDataCoord{}
dataCoord.expireTime = Timestamp(1000)
segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick1)
assert.NoError(t, err)
wg := &sync.WaitGroup{}
segAllocator.Start()
wg.Add(1)
go func(group *sync.WaitGroup) {
defer group.Done()
time.Sleep(100 * time.Millisecond)
segAllocator.Close()
}(wg)
total := uint32(0)
collNames := []string{"abc", "cba"}
for i := 0; i < 10; i++ {
colName := collNames[i%2]
ret, err := segAllocator.GetSegmentID(1, 1, colName, 1, 1)
assert.NoError(t, err)
total += ret[1]
}
assert.Equal(t, uint32(10), total)
ret, err := segAllocator.GetSegmentID(1, 1, "abc", segCountPerRPC-10, 999)
assert.NoError(t, err)
assert.Equal(t, uint32(segCountPerRPC-10), ret[1])
_, err = segAllocator.GetSegmentID(1, 1, "abc", 10, 1001)
assert.Error(t, err)
wg.Wait()
}
var curLastTick2 = Timestamp(200)
var curLastTIck2Lock sync.Mutex
func getLastTick2() Timestamp {
curLastTIck2Lock.Lock()
defer curLastTIck2Lock.Unlock()
curLastTick2 += 100
return curLastTick2
}
func TestSegmentAllocator2(t *testing.T) {
ctx := context.Background()
dataCoord := &mockDataCoord{}
dataCoord.expireTime = Timestamp(500)
segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick2)
assert.NoError(t, err)
segAllocator.Start()
defer segAllocator.Close()
total := uint32(0)
for i := 0; i < 10; i++ {
ret, err := segAllocator.GetSegmentID(1, 1, "abc", 1, 200)
assert.NoError(t, err)
total += ret[1]
}
assert.Equal(t, uint32(10), total)
time.Sleep(50 * time.Millisecond)
_, err = segAllocator.GetSegmentID(1, 1, "abc", segCountPerRPC-10, getLastTick2())
assert.Error(t, err)
}
func TestSegmentAllocator3(t *testing.T) {
ctx := context.Background()
dataCoord := &mockDataCoord2{}
dataCoord.expireTime = Timestamp(500)
segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick2)
assert.NoError(t, err)
wg := &sync.WaitGroup{}
segAllocator.Start()
wg.Add(1)
go func(group *sync.WaitGroup) {
defer group.Done()
time.Sleep(100 * time.Millisecond)
segAllocator.Close()
}(wg)
time.Sleep(50 * time.Millisecond)
_, err = segAllocator.GetSegmentID(1, 1, "abc", 10, 100)
assert.Error(t, err)
wg.Wait()
}
type mockDataCoord3 struct {
expireTime Timestamp
}
func (mockD *mockDataCoord3) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) {
assigns := make([]*datapb.SegmentIDAssignment, 0, len(req.SegmentIDRequests))
for i, r := range req.SegmentIDRequests {
errCode := commonpb.ErrorCode_Success
reason := ""
if i == 0 {
errCode = commonpb.ErrorCode_UnexpectedError
reason = "Just for test"
}
result := &datapb.SegmentIDAssignment{
SegID: 1,
ChannelName: r.ChannelName,
Count: r.Count,
CollectionID: r.CollectionID,
PartitionID: r.PartitionID,
ExpireTime: mockD.expireTime,
Status: &commonpb.Status{
ErrorCode: errCode,
Reason: reason,
},
}
assigns = append(assigns, result)
}
return &datapb.AssignSegmentIDResponse{
Status: merr.Success(),
SegIDAssignments: assigns,
}, nil
}
func TestSegmentAllocator4(t *testing.T) {
ctx := context.Background()
dataCoord := &mockDataCoord3{}
dataCoord.expireTime = Timestamp(500)
segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick2)
assert.NoError(t, err)
wg := &sync.WaitGroup{}
segAllocator.Start()
wg.Add(1)
go func(group *sync.WaitGroup) {
defer group.Done()
time.Sleep(100 * time.Millisecond)
segAllocator.Close()
}(wg)
time.Sleep(50 * time.Millisecond)
_, err = segAllocator.GetSegmentID(1, 1, "abc", 10, 100)
assert.Error(t, err)
wg.Wait()
}
type mockDataCoord5 struct {
expireTime Timestamp
}
func (mockD *mockDataCoord5) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) {
return &datapb.AssignSegmentIDResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "Just For Test",
},
}, errors.New("just for test")
}
func TestSegmentAllocator5(t *testing.T) {
ctx := context.Background()
dataCoord := &mockDataCoord5{}
dataCoord.expireTime = Timestamp(500)
segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick2)
assert.NoError(t, err)
wg := &sync.WaitGroup{}
segAllocator.Start()
wg.Add(1)
go func(group *sync.WaitGroup) {
defer group.Done()
time.Sleep(100 * time.Millisecond)
segAllocator.Close()
}(wg)
time.Sleep(50 * time.Millisecond)
_, err = segAllocator.GetSegmentID(1, 1, "abc", 10, 100)
assert.Error(t, err)
wg.Wait()
}
func TestSegmentAllocator6(t *testing.T) {
ctx := context.Background()
dataCoord := &mockDataCoord{}
dataCoord.expireTime = Timestamp(500)
segAllocator, err := newSegIDAssigner(ctx, dataCoord, getLastTick2)
assert.NoError(t, err)
wg := &sync.WaitGroup{}
segAllocator.Start()
wg.Add(1)
go func(group *sync.WaitGroup) {
defer group.Done()
time.Sleep(100 * time.Millisecond)
segAllocator.Close()
}(wg)
success := true
var sucLock sync.Mutex
collNames := []string{"abc", "cba"}
reqFunc := func(i int, group *sync.WaitGroup) {
defer group.Done()
sucLock.Lock()
defer sucLock.Unlock()
if !success {
return
}
colName := collNames[i%2]
count := uint32(10)
if i == 0 {
count = 0
}
_, err = segAllocator.GetSegmentID(1, 1, colName, count, 100)
if err != nil {
t.Log(err)
success = false
}
}
for i := 0; i < 10; i++ {
wg.Add(1)
go reqFunc(i, wg)
}
wg.Wait()
assert.True(t, success)
}

View File

@ -587,7 +587,6 @@ type dropCollectionTask struct {
mixCoord types.MixCoordClient
result *commonpb.Status
chMgr channelsMgr
chTicker channelsTimeTicker
}
func (t *dropCollectionTask) TraceCtx() context.Context {
@ -1047,10 +1046,9 @@ type alterCollectionTask struct {
baseTask
Condition
*milvuspb.AlterCollectionRequest
ctx context.Context
mixCoord types.MixCoordClient
result *commonpb.Status
replicateMsgStream msgstream.MsgStream
ctx context.Context
mixCoord types.MixCoordClient
result *commonpb.Status
}
func (t *alterCollectionTask) TraceCtx() context.Context {
@ -1266,7 +1264,6 @@ func (t *alterCollectionTask) Execute(ctx context.Context) error {
if err = merr.CheckRPCCall(t.result, err); err != nil {
return err
}
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.AlterCollectionRequest)
return nil
}
@ -1278,10 +1275,9 @@ type alterCollectionFieldTask struct {
baseTask
Condition
*milvuspb.AlterCollectionFieldRequest
ctx context.Context
mixCoord types.MixCoordClient
result *commonpb.Status
replicateMsgStream msgstream.MsgStream
ctx context.Context
mixCoord types.MixCoordClient
result *commonpb.Status
}
func (t *alterCollectionFieldTask) TraceCtx() context.Context {
@ -1495,7 +1491,6 @@ func (t *alterCollectionFieldTask) Execute(ctx context.Context) error {
if err = merr.CheckRPCCall(t.result, err); err != nil {
return err
}
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.AlterCollectionFieldRequest)
return nil
}
@ -1938,8 +1933,7 @@ type loadCollectionTask struct {
mixCoord types.MixCoordClient
result *commonpb.Status
collectionID UniqueID
replicateMsgStream msgstream.MsgStream
collectionID UniqueID
}
func (t *loadCollectionTask) TraceCtx() context.Context {
@ -2088,7 +2082,6 @@ func (t *loadCollectionTask) Execute(ctx context.Context) (err error) {
if err = merr.CheckRPCCall(t.result, err); err != nil {
return fmt.Errorf("call query coordinator LoadCollection: %s", err)
}
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.LoadCollectionRequest)
return nil
}
@ -2111,8 +2104,7 @@ type releaseCollectionTask struct {
mixCoord types.MixCoordClient
result *commonpb.Status
collectionID UniqueID
replicateMsgStream msgstream.MsgStream
collectionID UniqueID
}
func (t *releaseCollectionTask) TraceCtx() context.Context {
@ -2186,7 +2178,6 @@ func (t *releaseCollectionTask) Execute(ctx context.Context) (err error) {
return err
}
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.ReleaseCollectionRequest)
return nil
}
@ -2202,8 +2193,7 @@ type loadPartitionsTask struct {
mixCoord types.MixCoordClient
result *commonpb.Status
collectionID UniqueID
replicateMsgStream msgstream.MsgStream
collectionID UniqueID
}
func (t *loadPartitionsTask) TraceCtx() context.Context {
@ -2362,7 +2352,6 @@ func (t *loadPartitionsTask) Execute(ctx context.Context) error {
if err = merr.CheckRPCCall(t.result, err); err != nil {
return err
}
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.LoadPartitionsRequest)
return nil
}
@ -2379,8 +2368,7 @@ type releasePartitionsTask struct {
mixCoord types.MixCoordClient
result *commonpb.Status
collectionID UniqueID
replicateMsgStream msgstream.MsgStream
collectionID UniqueID
}
func (t *releasePartitionsTask) TraceCtx() context.Context {
@ -2469,7 +2457,6 @@ func (t *releasePartitionsTask) Execute(ctx context.Context) (err error) {
if err = merr.CheckRPCCall(t.result, err); err != nil {
return err
}
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.ReleasePartitionsRequest)
return nil
}

View File

@ -22,7 +22,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
@ -33,10 +32,9 @@ type CreateAliasTask struct {
baseTask
Condition
*milvuspb.CreateAliasRequest
ctx context.Context
mixCoord types.MixCoordClient
replicateMsgStream msgstream.MsgStream
result *commonpb.Status
ctx context.Context
mixCoord types.MixCoordClient
result *commonpb.Status
}
// TraceCtx returns the trace context of the task.
@ -111,7 +109,6 @@ func (t *CreateAliasTask) Execute(ctx context.Context) error {
if err = merr.CheckRPCCall(t.result, err); err != nil {
return err
}
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.CreateAliasRequest)
return nil
}
@ -125,10 +122,9 @@ type DropAliasTask struct {
baseTask
Condition
*milvuspb.DropAliasRequest
ctx context.Context
mixCoord types.MixCoordClient
replicateMsgStream msgstream.MsgStream
result *commonpb.Status
ctx context.Context
mixCoord types.MixCoordClient
result *commonpb.Status
}
// TraceCtx returns the context for trace
@ -189,7 +185,6 @@ func (t *DropAliasTask) Execute(ctx context.Context) error {
if err = merr.CheckRPCCall(t.result, err); err != nil {
return err
}
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.DropAliasRequest)
return nil
}
@ -202,10 +197,9 @@ type AlterAliasTask struct {
baseTask
Condition
*milvuspb.AlterAliasRequest
ctx context.Context
mixCoord types.MixCoordClient
replicateMsgStream msgstream.MsgStream
result *commonpb.Status
ctx context.Context
mixCoord types.MixCoordClient
result *commonpb.Status
}
func (t *AlterAliasTask) TraceCtx() context.Context {
@ -270,7 +264,6 @@ func (t *AlterAliasTask) Execute(ctx context.Context) error {
if err = merr.CheckRPCCall(t.result, err); err != nil {
return err
}
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.AlterAliasRequest)
return nil
}

View File

@ -11,7 +11,6 @@ import (
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
@ -25,8 +24,6 @@ type createDatabaseTask struct {
ctx context.Context
mixCoord types.MixCoordClient
result *commonpb.Status
replicateMsgStream msgstream.MsgStream
}
func (cdt *createDatabaseTask) TraceCtx() context.Context {
@ -78,9 +75,6 @@ func (cdt *createDatabaseTask) Execute(ctx context.Context) error {
var err error
cdt.result, err = cdt.mixCoord.CreateDatabase(ctx, cdt.CreateDatabaseRequest)
err = merr.CheckRPCCall(cdt.result, err)
if err == nil {
SendReplicateMessagePack(ctx, cdt.replicateMsgStream, cdt.CreateDatabaseRequest)
}
return err
}
@ -95,8 +89,6 @@ type dropDatabaseTask struct {
ctx context.Context
mixCoord types.MixCoordClient
result *commonpb.Status
replicateMsgStream msgstream.MsgStream
}
func (ddt *dropDatabaseTask) TraceCtx() context.Context {
@ -151,7 +143,6 @@ func (ddt *dropDatabaseTask) Execute(ctx context.Context) error {
err = merr.CheckRPCCall(ddt.result, err)
if err == nil {
globalMetaCache.RemoveDatabase(ctx, ddt.DbName)
SendReplicateMessagePack(ctx, ddt.replicateMsgStream, ddt.DropDatabaseRequest)
}
return err
}
@ -230,8 +221,6 @@ type alterDatabaseTask struct {
ctx context.Context
mixCoord types.MixCoordClient
result *commonpb.Status
replicateMsgStream msgstream.MsgStream
}
func (t *alterDatabaseTask) TraceCtx() context.Context {
@ -323,8 +312,6 @@ func (t *alterDatabaseTask) Execute(ctx context.Context) error {
if err != nil {
return err
}
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.AlterDatabaseRequest)
t.result = ret
return nil
}

View File

@ -2,13 +2,11 @@ package proxy
import (
"context"
"fmt"
"io"
"strconv"
"time"
"github.com/cockroachdb/errors"
"go.opentelemetry.io/otel"
"go.uber.org/atomic"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
@ -21,7 +19,6 @@ import (
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/exprutil"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
@ -133,60 +130,6 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error {
return nil
}
func (dt *deleteTask) Execute(ctx context.Context) (err error) {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Delete-Execute")
defer sp.End()
if len(dt.req.GetExpr()) == 0 {
return merr.WrapErrParameterInvalid("valid expr", "empty expr", "invalid expression")
}
dt.tr = timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute delete %d", dt.ID()))
stream, err := dt.chMgr.getOrCreateDmlStream(ctx, dt.collectionID)
if err != nil {
return err
}
result, numRows, err := repackDeleteMsgByHash(
ctx,
dt.primaryKeys, dt.vChannels,
dt.idAllocator, dt.ts,
dt.collectionID, dt.req.GetCollectionName(),
dt.partitionID, dt.req.GetPartitionName(),
dt.req.GetDbName(),
)
if err != nil {
return err
}
// send delete request to log broker
msgPack := &msgstream.MsgPack{
BeginTs: dt.BeginTs(),
EndTs: dt.EndTs(),
}
for _, msgs := range result {
for _, msg := range msgs {
msgPack.Msgs = append(msgPack.Msgs, msg)
}
}
log.Ctx(ctx).Debug("send delete request to virtual channels",
zap.String("collectionName", dt.req.GetCollectionName()),
zap.Int64("collectionID", dt.collectionID),
zap.Strings("virtual_channels", dt.vChannels),
zap.Int64("taskID", dt.ID()),
zap.Duration("prepare duration", dt.tr.RecordSpan()))
err = stream.Produce(ctx, msgPack)
if err != nil {
return err
}
dt.sessionTS = dt.ts
dt.count += numRows
return nil
}
func (dt *deleteTask) PostExecute(ctx context.Context) error {
metrics.ProxyDeleteVectors.WithLabelValues(
paramtable.GetStringNodeID(),
@ -288,7 +231,6 @@ type deleteRunner struct {
// channel
chMgr channelsMgr
chTicker channelsTimeTicker
vChannels []vChan
idAllocator allocator.Interface
@ -437,20 +379,13 @@ func (dr *deleteRunner) produce(ctx context.Context, primaryKeys *schemapb.IDs,
req: dr.req,
idAllocator: dr.idAllocator,
chMgr: dr.chMgr,
chTicker: dr.chTicker,
collectionID: dr.collectionID,
partitionID: partitionID,
vChannels: dr.vChannels,
primaryKeys: primaryKeys,
dbID: dr.dbID,
}
var enqueuedTask task = dt
if streamingutil.IsStreamingServiceEnabled() {
enqueuedTask = &deleteTaskByStreamingService{deleteTask: dt}
}
if err := dr.queue.Enqueue(enqueuedTask); err != nil {
if err := dr.queue.Enqueue(dt); err != nil {
log.Ctx(ctx).Error("Failed to enqueue delete task: " + err.Error())
return nil, err
}

View File

@ -15,13 +15,9 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
type deleteTaskByStreamingService struct {
*deleteTask
}
// Execute is a function to delete task by streaming service
// we only overwrite the Execute function
func (dt *deleteTaskByStreamingService) Execute(ctx context.Context) (err error) {
func (dt *deleteTask) Execute(ctx context.Context) (err error) {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Delete-Execute")
defer sp.End()

View File

@ -14,11 +14,11 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/util/streamrpc"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
@ -152,19 +152,6 @@ func TestDeleteTask_Execute(t *testing.T) {
assert.Error(t, dt.Execute(context.Background()))
})
t.Run("get channel failed", func(t *testing.T) {
mockMgr := NewMockChannelsMgr(t)
dt := deleteTask{
chMgr: mockMgr,
req: &milvuspb.DeleteRequest{
Expr: "pk in [1,2]",
},
}
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
assert.Error(t, dt.Execute(context.Background()))
})
t.Run("alloc failed", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -189,9 +176,6 @@ func TestDeleteTask_Execute(t *testing.T) {
},
primaryKeys: pk,
}
stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
assert.Error(t, dt.Execute(context.Background()))
})
@ -225,9 +209,7 @@ func TestDeleteTask_Execute(t *testing.T) {
},
primaryKeys: pk,
}
stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("mock error"))
streaming.ExpectErrorOnce(errors.New("mock error"))
assert.Error(t, dt.Execute(context.Background()))
})
}
@ -666,7 +648,7 @@ func TestDeleteRunner_Run(t *testing.T) {
tsoAllocator := &mockTsoAllocator{}
idAllocator := &mockIDAllocatorInterface{}
queue, err := newTaskScheduler(ctx, tsoAllocator, nil)
queue, err := newTaskScheduler(ctx, tsoAllocator)
assert.NoError(t, err)
queue.Start()
defer queue.Close()
@ -731,11 +713,8 @@ func TestDeleteRunner_Run(t *testing.T) {
},
plan: plan,
}
stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("mock error"))
streaming.ExpectErrorOnce(errors.New("mock error"))
assert.Error(t, dr.Run(context.Background()))
assert.Equal(t, int64(0), dr.result.DeleteCnt)
})
@ -814,10 +793,7 @@ func TestDeleteRunner_Run(t *testing.T) {
},
plan: plan,
}
stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "")
@ -946,8 +922,6 @@ func TestDeleteRunner_Run(t *testing.T) {
},
plan: plan,
}
stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "")
@ -971,8 +945,8 @@ func TestDeleteRunner_Run(t *testing.T) {
server.FinishSend(nil)
return client
}, nil)
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("mock error"))
streaming.ExpectErrorOnce(errors.New("mock error"))
assert.Error(t, dr.Run(ctx))
assert.Equal(t, int64(0), dr.result.DeleteCnt)
})
@ -1012,8 +986,6 @@ func TestDeleteRunner_Run(t *testing.T) {
},
plan: plan,
}
stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "")
@ -1037,8 +1009,6 @@ func TestDeleteRunner_Run(t *testing.T) {
server.FinishSend(nil)
return client
}, nil)
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
assert.NoError(t, dr.Run(ctx))
assert.Equal(t, int64(3), dr.result.DeleteCnt)
})
@ -1088,8 +1058,6 @@ func TestDeleteRunner_Run(t *testing.T) {
},
plan: plan,
}
stream := msgstream.NewMockMsgStream(t)
mockMgr.EXPECT().getOrCreateDmlStream(mock.Anything, mock.Anything).Return(stream, nil)
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "")
@ -1114,7 +1082,6 @@ func TestDeleteRunner_Run(t *testing.T) {
return client
}, nil)
stream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
assert.NoError(t, dr.Run(ctx))
assert.Equal(t, int64(3), dr.result.DeleteCnt)
})

View File

@ -18,17 +18,11 @@ package proxy
import (
"context"
"fmt"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
@ -40,7 +34,7 @@ type flushTask struct {
mixCoord types.MixCoordClient
result *milvuspb.FlushResponse
replicateMsgStream msgstream.MsgStream
chMgr channelsMgr
}
func (t *flushTask) TraceCtx() context.Context {
@ -88,46 +82,6 @@ func (t *flushTask) PreExecute(ctx context.Context) error {
return nil
}
func (t *flushTask) Execute(ctx context.Context) error {
coll2Segments := make(map[string]*schemapb.LongArray)
flushColl2Segments := make(map[string]*schemapb.LongArray)
coll2SealTimes := make(map[string]int64)
coll2FlushTs := make(map[string]Timestamp)
channelCps := make(map[string]*msgpb.MsgPosition)
for _, collName := range t.CollectionNames {
collID, err := globalMetaCache.GetCollectionID(ctx, t.GetDbName(), collName)
if err != nil {
return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound)
}
flushReq := &datapb.FlushRequest{
Base: commonpbutil.UpdateMsgBase(
t.Base,
commonpbutil.WithMsgType(commonpb.MsgType_Flush),
),
CollectionID: collID,
}
resp, err := t.mixCoord.Flush(ctx, flushReq)
if err = merr.CheckRPCCall(resp, err); err != nil {
return fmt.Errorf("failed to call flush to data coordinator: %s", err.Error())
}
coll2Segments[collName] = &schemapb.LongArray{Data: resp.GetSegmentIDs()}
flushColl2Segments[collName] = &schemapb.LongArray{Data: resp.GetFlushSegmentIDs()}
coll2SealTimes[collName] = resp.GetTimeOfSeal()
coll2FlushTs[collName] = resp.GetFlushTs()
channelCps = resp.GetChannelCps()
}
t.result = &milvuspb.FlushResponse{
Status: merr.Success(),
DbName: t.GetDbName(),
CollSegIDs: coll2Segments,
FlushCollSegIDs: flushColl2Segments,
CollSealTimes: coll2SealTimes,
CollFlushTs: coll2FlushTs,
ChannelCps: channelCps,
}
return nil
}
func (t *flushTask) PostExecute(ctx context.Context) error {
return nil
}

View File

@ -18,15 +18,12 @@ package proxy
import (
"context"
"fmt"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
@ -37,8 +34,7 @@ type flushAllTask struct {
ctx context.Context
mixCoord types.MixCoordClient
result *datapb.FlushAllResponse
replicateMsgStream msgstream.MsgStream
chMgr channelsMgr
}
func (t *flushAllTask) TraceCtx() context.Context {
@ -86,22 +82,6 @@ func (t *flushAllTask) PreExecute(ctx context.Context) error {
return nil
}
func (t *flushAllTask) Execute(ctx context.Context) error {
flushAllReq := &datapb.FlushAllRequest{
Base: commonpbutil.UpdateMsgBase(
t.Base,
commonpbutil.WithMsgType(commonpb.MsgType_Flush),
),
}
resp, err := t.mixCoord.FlushAll(ctx, flushAllReq)
if err = merr.CheckRPCCall(resp, err); err != nil {
return fmt.Errorf("failed to call flush all to data coordinator: %s", err.Error())
}
t.result = resp
return nil
}
func (t *flushAllTask) PostExecute(ctx context.Context) error {
return nil
}

View File

@ -30,12 +30,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
type flushAllTaskbyStreamingService struct {
*flushAllTask
chMgr channelsMgr
}
func (t *flushAllTaskbyStreamingService) Execute(ctx context.Context) error {
func (t *flushAllTask) Execute(ctx context.Context) error {
dbNames := make([]string, 0)
if t.GetDbName() != "" {
dbNames = append(dbNames, t.GetDbName())
@ -60,7 +55,7 @@ func (t *flushAllTaskbyStreamingService) Execute(ctx context.Context) error {
Base: commonpbutil.NewMsgBase(commonpbutil.WithMsgType(commonpb.MsgType_ShowCollections)),
DbName: dbName,
})
if err != nil {
if err := merr.CheckRPCCall(showColRsp, err); err != nil {
log.Info("flush all task by streaming service failed, show collections failed", zap.String("dbName", dbName), zap.Error(err))
return err
}

View File

@ -33,7 +33,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
func createTestFlushAllTaskByStreamingService(t *testing.T, dbName string) (*flushAllTaskbyStreamingService, *mocks.MockMixCoordClient, *msgstream.MockMsgStream, *MockChannelsMgr, context.Context) {
func createTestFlushAllTaskByStreamingService(t *testing.T, dbName string) (*flushAllTask, *mocks.MockMixCoordClient, *msgstream.MockMsgStream, *MockChannelsMgr, context.Context) {
ctx := context.Background()
mixCoord := mocks.NewMockMixCoordClient(t)
replicateMsgStream := msgstream.NewMockMsgStream(t)
@ -51,17 +51,11 @@ func createTestFlushAllTaskByStreamingService(t *testing.T, dbName string) (*flu
},
DbName: dbName,
},
ctx: ctx,
mixCoord: mixCoord,
replicateMsgStream: replicateMsgStream,
ctx: ctx,
mixCoord: mixCoord,
chMgr: chMgr,
}
task := &flushAllTaskbyStreamingService{
flushAllTask: baseTask,
chMgr: chMgr,
}
return task, mixCoord, replicateMsgStream, chMgr, ctx
return baseTask, mixCoord, replicateMsgStream, chMgr, ctx
}
func TestFlushAllTask_WithSpecificDB(t *testing.T) {

View File

@ -18,19 +18,15 @@ package proxy
import (
"context"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/uniquegenerator"
)
@ -50,9 +46,8 @@ func createTestFlushAllTask(t *testing.T) (*flushAllTask, *mocks.MockMixCoordCli
SourceID: 1,
},
},
ctx: ctx,
mixCoord: mixCoord,
replicateMsgStream: replicateMsgStream,
ctx: ctx,
mixCoord: mixCoord,
}
return task, mixCoord, replicateMsgStream, ctx
@ -151,95 +146,6 @@ func TestFlushAllTaskPreExecute(t *testing.T) {
assert.NoError(t, err)
}
func TestFlushAllTaskExecuteSuccess(t *testing.T) {
task, mixCoord, replicateMsgStream, ctx := createTestFlushAllTask(t)
defer mixCoord.AssertExpectations(t)
defer replicateMsgStream.AssertExpectations(t)
// Setup expectations
expectedResp := &datapb.FlushAllResponse{
Status: merr.Success(),
}
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
Return(expectedResp, nil).Once()
err := task.Execute(ctx)
assert.NoError(t, err)
assert.Equal(t, expectedResp, task.result)
}
func TestFlushAllTaskExecuteFlushAllRPCError(t *testing.T) {
task, mixCoord, replicateMsgStream, ctx := createTestFlushAllTask(t)
defer mixCoord.AssertExpectations(t)
defer replicateMsgStream.AssertExpectations(t)
// Test RPC call error
expectedErr := fmt.Errorf("rpc error")
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
Return(nil, expectedErr).Once()
err := task.Execute(ctx)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to call flush all to data coordinator")
}
func TestFlushAllTaskExecuteFlushAllResponseError(t *testing.T) {
task, mixCoord, replicateMsgStream, ctx := createTestFlushAllTask(t)
defer mixCoord.AssertExpectations(t)
defer replicateMsgStream.AssertExpectations(t)
// Test response with error status
errorResp := &datapb.FlushAllResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "flush all failed",
},
}
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
Return(errorResp, nil).Once()
err := task.Execute(ctx)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to call flush all to data coordinator")
}
func TestFlushAllTaskExecuteWithMerCheck(t *testing.T) {
task, mixCoord, replicateMsgStream, ctx := createTestFlushAllTask(t)
defer mixCoord.AssertExpectations(t)
defer replicateMsgStream.AssertExpectations(t)
// Test successful execution with merr.CheckRPCCall
successResp := &datapb.FlushAllResponse{
Status: merr.Success(),
}
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
Return(successResp, nil).Once()
err := task.Execute(ctx)
assert.NoError(t, err)
assert.Equal(t, successResp, task.result)
}
func TestFlushAllTaskExecuteRequestContent(t *testing.T) {
task, mixCoord, replicateMsgStream, ctx := createTestFlushAllTask(t)
defer mixCoord.AssertExpectations(t)
defer replicateMsgStream.AssertExpectations(t)
// Test the content of the FlushAllRequest sent to mixCoord
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
Return(&datapb.FlushAllResponse{Status: merr.Success()}, nil).Once()
err := task.Execute(ctx)
assert.NoError(t, err)
// The test verifies that Execute method creates the correct request structure internally
// The actual request content validation is covered by other tests
}
func TestFlushAllTaskPostExecute(t *testing.T) {
task, mixCoord, replicateMsgStream, ctx := createTestFlushAllTask(t)
defer mixCoord.AssertExpectations(t)
@ -249,97 +155,6 @@ func TestFlushAllTaskPostExecute(t *testing.T) {
assert.NoError(t, err)
}
func TestFlushAllTaskLifecycle(t *testing.T) {
ctx := context.Background()
mixCoord := mocks.NewMockMixCoordClient(t)
replicateMsgStream := msgstream.NewMockMsgStream(t)
defer mixCoord.AssertExpectations(t)
defer replicateMsgStream.AssertExpectations(t)
// Test complete task lifecycle
// 1. OnEnqueue
task := &flushAllTask{
baseTask: baseTask{},
Condition: NewTaskCondition(ctx),
FlushAllRequest: &milvuspb.FlushAllRequest{},
ctx: ctx,
mixCoord: mixCoord,
replicateMsgStream: replicateMsgStream,
}
err := task.OnEnqueue()
assert.NoError(t, err)
// 2. PreExecute
err = task.PreExecute(ctx)
assert.NoError(t, err)
// 3. Execute
expectedResp := &datapb.FlushAllResponse{
Status: merr.Success(),
}
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
Return(expectedResp, nil).Once()
err = task.Execute(ctx)
assert.NoError(t, err)
// 4. PostExecute
err = task.PostExecute(ctx)
assert.NoError(t, err)
// Verify task state
assert.Equal(t, expectedResp, task.result)
}
func TestFlushAllTaskErrorHandlingInExecute(t *testing.T) {
// Test different error scenarios in Execute method
testCases := []struct {
name string
setupMock func(*mocks.MockMixCoordClient)
expectedError string
}{
{
name: "mixCoord FlushAll returns error",
setupMock: func(mixCoord *mocks.MockMixCoordClient) {
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
Return(nil, fmt.Errorf("network error")).Once()
},
expectedError: "failed to call flush all to data coordinator",
},
{
name: "mixCoord FlushAll returns error status",
setupMock: func(mixCoord *mocks.MockMixCoordClient) {
mixCoord.EXPECT().FlushAll(mock.Anything, mock.AnythingOfType("*datapb.FlushAllRequest")).
Return(&datapb.FlushAllResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_IllegalArgument,
Reason: "invalid request",
},
}, nil).Once()
},
expectedError: "failed to call flush all to data coordinator",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
task, mixCoord, replicateMsgStream, ctx := createTestFlushAllTask(t)
defer mixCoord.AssertExpectations(t)
defer replicateMsgStream.AssertExpectations(t)
tc.setupMock(mixCoord)
err := task.Execute(ctx)
assert.Error(t, err)
assert.Contains(t, err.Error(), tc.expectedError)
})
}
}
func TestFlushAllTaskImplementsTaskInterface(t *testing.T) {
// Verify that flushAllTask implements the task interface
var _ task = (*flushAllTask)(nil)

View File

@ -35,12 +35,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
)
type flushTaskByStreamingService struct {
*flushTask
chMgr channelsMgr
}
func (t *flushTaskByStreamingService) Execute(ctx context.Context) error {
func (t *flushTask) Execute(ctx context.Context) error {
coll2Segments := make(map[string]*schemapb.LongArray)
flushColl2Segments := make(map[string]*schemapb.LongArray)
coll2SealTimes := make(map[string]int64)

View File

@ -32,7 +32,6 @@ import (
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
@ -64,8 +63,6 @@ type createIndexTask struct {
mixCoord types.MixCoordClient
result *commonpb.Status
replicateMsgStream msgstream.MsgStream
isAutoIndex bool
newIndexParams []*commonpb.KeyValuePair
newTypeParams []*commonpb.KeyValuePair
@ -580,7 +577,6 @@ func (cit *createIndexTask) Execute(ctx context.Context) error {
if err = merr.CheckRPCCall(cit.result, err); err != nil {
return err
}
SendReplicateMessagePack(ctx, cit.replicateMsgStream, cit.req)
return nil
}
@ -596,8 +592,6 @@ type alterIndexTask struct {
mixCoord types.MixCoordClient
result *commonpb.Status
replicateMsgStream msgstream.MsgStream
collectionID UniqueID
}
@ -711,7 +705,6 @@ func (t *alterIndexTask) Execute(ctx context.Context) error {
if err = merr.CheckRPCCall(t.result, err); err != nil {
return err
}
SendReplicateMessagePack(ctx, t.replicateMsgStream, t.req)
return nil
}
@ -978,8 +971,6 @@ type dropIndexTask struct {
result *commonpb.Status
collectionID UniqueID
replicateMsgStream msgstream.MsgStream
}
func (dit *dropIndexTask) TraceCtx() context.Context {
@ -1073,7 +1064,6 @@ func (dit *dropIndexTask) Execute(ctx context.Context) error {
ctxLog.Warn("drop index failed", zap.Error(err))
return err
}
SendReplicateMessagePack(ctx, dit.replicateMsgStream, dit.DropIndexRequest)
return nil
}

View File

@ -31,6 +31,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
@ -46,6 +47,7 @@ import (
func TestMain(m *testing.M) {
paramtable.Init()
gin.SetMode(gin.TestMode)
streaming.SetupNoopWALForTest()
code := m.Run()
os.Exit(code)
}

View File

@ -2,7 +2,6 @@ package proxy
import (
"context"
"fmt"
"strconv"
"go.opentelemetry.io/otel"
@ -15,7 +14,6 @@ import (
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
@ -32,9 +30,7 @@ type insertTask struct {
result *milvuspb.MutationResult
idAllocator *allocator.IDAllocator
segIDAssigner *segIDAssigner
chMgr channelsMgr
chTicker channelsTimeTicker
vChannels []vChan
pChannels []pChan
schema *schemapb.CollectionSchema
@ -292,76 +288,6 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
return nil
}
func (it *insertTask) Execute(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Insert-Execute")
defer sp.End()
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute insert %d", it.ID()))
collectionName := it.insertMsg.CollectionName
collID, err := globalMetaCache.GetCollectionID(it.ctx, it.insertMsg.GetDbName(), collectionName)
log := log.Ctx(ctx)
if err != nil {
log.Warn("fail to get collection id", zap.Error(err))
return err
}
it.insertMsg.CollectionID = collID
getCacheDur := tr.RecordSpan()
stream, err := it.chMgr.getOrCreateDmlStream(ctx, collID)
if err != nil {
return err
}
getMsgStreamDur := tr.RecordSpan()
channelNames, err := it.chMgr.getVChannels(collID)
if err != nil {
log.Warn("get vChannels failed", zap.Int64("collectionID", collID), zap.Error(err))
it.result.Status = merr.Status(err)
return err
}
log.Debug("send insert request to virtual channels",
zap.String("partition", it.insertMsg.GetPartitionName()),
zap.Int64("collectionID", collID),
zap.Strings("virtual_channels", channelNames),
zap.Int64("task_id", it.ID()),
zap.Duration("get cache duration", getCacheDur),
zap.Duration("get msgStream duration", getMsgStreamDur))
// assign segmentID for insert data and repack data by segmentID
var msgPack *msgstream.MsgPack
if it.partitionKeys == nil {
msgPack, err = repackInsertData(it.TraceCtx(), channelNames, it.insertMsg, it.result, it.idAllocator, it.segIDAssigner)
} else {
msgPack, err = repackInsertDataWithPartitionKey(it.TraceCtx(), channelNames, it.partitionKeys, it.insertMsg, it.result, it.idAllocator, it.segIDAssigner)
}
if err != nil {
log.Warn("assign segmentID and repack insert data failed", zap.Error(err))
it.result.Status = merr.Status(err)
return err
}
assignSegmentIDDur := tr.RecordSpan()
log.Debug("assign segmentID for insert data success",
zap.Duration("assign segmentID duration", assignSegmentIDDur))
err = stream.Produce(ctx, msgPack)
if err != nil {
log.Warn("fail to produce insert msg", zap.Error(err))
it.result.Status = merr.Status(err)
return err
}
sendMsgDur := tr.RecordSpan()
metrics.ProxySendMutationReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.InsertLabel).Observe(float64(sendMsgDur.Milliseconds()))
totalExecDur := tr.ElapseSpan()
log.Debug("Proxy Insert Execute done",
zap.String("collectionName", collectionName),
zap.Duration("send message duration", sendMsgDur),
zap.Duration("execute duration", totalExecDur))
return nil
}
func (it *insertTask) PostExecute(ctx context.Context) error {
return nil
}

View File

@ -18,12 +18,8 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
type insertTaskByStreamingService struct {
*insertTask
}
// we only overwrite the Execute function
func (it *insertTaskByStreamingService) Execute(ctx context.Context) error {
func (it *insertTask) Execute(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Insert-Execute")
defer sp.End()

View File

@ -29,7 +29,6 @@ import (
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/util/conc"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metricsinfo"
@ -437,22 +436,18 @@ type taskScheduler struct {
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
msFactory msgstream.Factory
}
type schedOpt func(*taskScheduler)
func newTaskScheduler(ctx context.Context,
tsoAllocatorIns tsoAllocator,
factory msgstream.Factory,
opts ...schedOpt,
) (*taskScheduler, error) {
ctx1, cancel := context.WithCancel(ctx)
s := &taskScheduler{
ctx: ctx1,
cancel: cancel,
msFactory: factory,
ctx: ctx1,
cancel: cancel,
}
s.ddQueue = newDdTaskQueue(tsoAllocatorIns)
s.dmQueue = newDmTaskQueue(tsoAllocatorIns)

View File

@ -492,9 +492,8 @@ func TestTaskScheduler(t *testing.T) {
ctx := context.Background()
tsoAllocatorIns := newMockTsoAllocator()
factory := newSimpleMockMsgStreamFactory()
sched, err := newTaskScheduler(ctx, tsoAllocatorIns, factory)
sched, err := newTaskScheduler(ctx, tsoAllocatorIns)
assert.NoError(t, err)
assert.NotNil(t, sched)
@ -572,8 +571,7 @@ func TestTaskScheduler_concurrentPushAndPop(t *testing.T) {
).Return(collectionID, nil)
globalMetaCache = cache
tsoAllocatorIns := newMockTsoAllocator()
factory := newSimpleMockMsgStreamFactory()
scheduler, err := newTaskScheduler(context.Background(), tsoAllocatorIns, factory)
scheduler, err := newTaskScheduler(context.Background(), tsoAllocatorIns)
assert.NoError(t, err)
run := func(wg *sync.WaitGroup) {

View File

@ -3586,7 +3586,7 @@ func TestSearchTask_Requery(t *testing.T) {
node.tsoAllocator = &timestampAllocator{
tso: newMockTimestampAllocatorInterface(),
}
scheduler, err := newTaskScheduler(ctx, node.tsoAllocator, factory)
scheduler, err := newTaskScheduler(ctx, node.tsoAllocator)
assert.NoError(t, err)
node.sched = scheduler
err = node.sched.Start()

View File

@ -2161,12 +2161,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
assert.NoError(t, err)
dmlChannelsFunc := getDmlChannelsFunc(ctx, qc)
factory := newSimpleMockMsgStreamFactory()
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
defer chMgr.removeAllDMLStream()
_, err = chMgr.getOrCreateDmlStream(ctx, collectionID)
assert.NoError(t, err)
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil)
pchans, err := chMgr.getChannels(collectionID)
assert.NoError(t, err)
@ -2182,11 +2177,6 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
_ = idAllocator.Start()
defer idAllocator.Close()
segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
assert.NoError(t, err)
_ = segAllocator.Start()
defer segAllocator.Close()
t.Run("insert", func(t *testing.T) {
hash := testutils.GenerateHashKeys(nb)
task := &insertTask{
@ -2221,13 +2211,11 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
UpsertCnt: 0,
Timestamp: 0,
},
idAllocator: idAllocator,
segIDAssigner: segAllocator,
chMgr: chMgr,
chTicker: ticker,
vChannels: nil,
pChannels: nil,
schema: nil,
idAllocator: idAllocator,
chMgr: chMgr,
vChannels: nil,
pChannels: nil,
schema: nil,
}
for fieldName, dataType := range fieldName2Types {
@ -2403,12 +2391,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
assert.NoError(t, err)
dmlChannelsFunc := getDmlChannelsFunc(ctx, mixc)
factory := newSimpleMockMsgStreamFactory()
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
defer chMgr.removeAllDMLStream()
_, err = chMgr.getOrCreateDmlStream(ctx, collectionID)
assert.NoError(t, err)
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil)
pchans, err := chMgr.getChannels(collectionID)
assert.NoError(t, err)
@ -2424,12 +2407,6 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
_ = idAllocator.Start()
defer idAllocator.Close()
segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
assert.NoError(t, err)
segAllocator.Init()
_ = segAllocator.Start()
defer segAllocator.Close()
t.Run("insert", func(t *testing.T) {
hash := testutils.GenerateHashKeys(nb)
task := &insertTask{
@ -2464,13 +2441,11 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
UpsertCnt: 0,
Timestamp: 0,
},
idAllocator: idAllocator,
segIDAssigner: segAllocator,
chMgr: chMgr,
chTicker: ticker,
vChannels: nil,
pChannels: nil,
schema: nil,
idAllocator: idAllocator,
chMgr: chMgr,
vChannels: nil,
pChannels: nil,
schema: nil,
}
fieldID := common.StartOfUserFieldID
@ -2549,13 +2524,12 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
UpsertCnt: 0,
Timestamp: 0,
},
idAllocator: idAllocator,
segIDAssigner: segAllocator,
chMgr: chMgr,
chTicker: ticker,
vChannels: nil,
pChannels: nil,
schema: nil,
idAllocator: idAllocator,
chMgr: chMgr,
chTicker: ticker,
vChannels: nil,
pChannels: nil,
schema: nil,
}
fieldID := common.StartOfUserFieldID
@ -3817,12 +3791,7 @@ func TestPartitionKey(t *testing.T) {
assert.NoError(t, err)
dmlChannelsFunc := getDmlChannelsFunc(ctx, qc)
factory := newSimpleMockMsgStreamFactory()
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
defer chMgr.removeAllDMLStream()
_, err = chMgr.getOrCreateDmlStream(ctx, collectionID)
assert.NoError(t, err)
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil)
pchans, err := chMgr.getChannels(collectionID)
assert.NoError(t, err)
@ -3838,12 +3807,6 @@ func TestPartitionKey(t *testing.T) {
_ = idAllocator.Start()
defer idAllocator.Close()
segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
assert.NoError(t, err)
segAllocator.Init()
_ = segAllocator.Start()
defer segAllocator.Close()
partitionNames, err := getDefaultPartitionsInPartitionKeyMode(ctx, "", collectionName)
assert.NoError(t, err)
assert.Equal(t, common.DefaultPartitionsWithPartitionKey, int64(len(partitionNames)))
@ -3888,13 +3851,11 @@ func TestPartitionKey(t *testing.T) {
UpsertCnt: 0,
Timestamp: 0,
},
idAllocator: idAllocator,
segIDAssigner: segAllocator,
chMgr: chMgr,
chTicker: ticker,
vChannels: nil,
pChannels: nil,
schema: nil,
idAllocator: idAllocator,
chMgr: chMgr,
vChannels: nil,
pChannels: nil,
schema: nil,
}
// don't support specify partition name if use partition key
@ -3932,10 +3893,9 @@ func TestPartitionKey(t *testing.T) {
IdField: nil,
},
},
idAllocator: idAllocator,
segIDAssigner: segAllocator,
chMgr: chMgr,
chTicker: ticker,
idAllocator: idAllocator,
chMgr: chMgr,
chTicker: ticker,
}
// don't support specify partition name if use partition key
@ -4060,12 +4020,8 @@ func TestDefaultPartition(t *testing.T) {
assert.NoError(t, err)
dmlChannelsFunc := getDmlChannelsFunc(ctx, qc)
factory := newSimpleMockMsgStreamFactory()
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
defer chMgr.removeAllDMLStream()
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil)
_, err = chMgr.getOrCreateDmlStream(ctx, collectionID)
assert.NoError(t, err)
pchans, err := chMgr.getChannels(collectionID)
assert.NoError(t, err)
@ -4081,12 +4037,6 @@ func TestDefaultPartition(t *testing.T) {
_ = idAllocator.Start()
defer idAllocator.Close()
segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
assert.NoError(t, err)
segAllocator.Init()
_ = segAllocator.Start()
defer segAllocator.Close()
nb := 10
fieldID := common.StartOfUserFieldID
fieldDatas := make([]*schemapb.FieldData, 0)
@ -4127,13 +4077,11 @@ func TestDefaultPartition(t *testing.T) {
UpsertCnt: 0,
Timestamp: 0,
},
idAllocator: idAllocator,
segIDAssigner: segAllocator,
chMgr: chMgr,
chTicker: ticker,
vChannels: nil,
pChannels: nil,
schema: nil,
idAllocator: idAllocator,
chMgr: chMgr,
vChannels: nil,
pChannels: nil,
schema: nil,
}
it.insertMsg.PartitionName = ""
@ -4167,10 +4115,9 @@ func TestDefaultPartition(t *testing.T) {
IdField: nil,
},
},
idAllocator: idAllocator,
segIDAssigner: segAllocator,
chMgr: chMgr,
chTicker: ticker,
idAllocator: idAllocator,
chMgr: chMgr,
chTicker: ticker,
}
ut.req.PartitionName = ""

View File

@ -17,7 +17,6 @@ package proxy
import (
"context"
"fmt"
"strconv"
"github.com/cockroachdb/errors"
@ -55,7 +54,6 @@ type upsertTask struct {
rowIDs []int64
result *milvuspb.MutationResult
idAllocator *allocator.IDAllocator
segIDAssigner *segIDAssigner
collectionID UniqueID
chMgr channelsMgr
chTicker channelsTimeTicker
@ -445,189 +443,6 @@ func (it *upsertTask) PreExecute(ctx context.Context) error {
return nil
}
func (it *upsertTask) insertExecute(ctx context.Context, msgPack *msgstream.MsgPack) error {
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy insertExecute upsert %d", it.ID()))
defer tr.Elapse("insert execute done when insertExecute")
collectionName := it.upsertMsg.InsertMsg.CollectionName
collID, err := globalMetaCache.GetCollectionID(ctx, it.req.GetDbName(), collectionName)
if err != nil {
return err
}
it.upsertMsg.InsertMsg.CollectionID = collID
it.upsertMsg.InsertMsg.BeginTimestamp = it.BeginTs()
it.upsertMsg.InsertMsg.EndTimestamp = it.EndTs()
log := log.Ctx(ctx).With(
zap.Int64("collectionID", collID))
getCacheDur := tr.RecordSpan()
_, err = it.chMgr.getOrCreateDmlStream(ctx, collID)
if err != nil {
return err
}
getMsgStreamDur := tr.RecordSpan()
channelNames, err := it.chMgr.getVChannels(collID)
if err != nil {
log.Warn("get vChannels failed when insertExecute",
zap.Error(err))
it.result.Status = merr.Status(err)
return err
}
log.Debug("send insert request to virtual channels when insertExecute",
zap.String("collection", it.req.GetCollectionName()),
zap.String("partition", it.req.GetPartitionName()),
zap.Int64("collection_id", collID),
zap.Strings("virtual_channels", channelNames),
zap.Int64("task_id", it.ID()),
zap.Duration("get cache duration", getCacheDur),
zap.Duration("get msgStream duration", getMsgStreamDur))
// assign segmentID for insert data and repack data by segmentID
var insertMsgPack *msgstream.MsgPack
if it.partitionKeys == nil {
insertMsgPack, err = repackInsertData(it.TraceCtx(), channelNames, it.upsertMsg.InsertMsg, it.result, it.idAllocator, it.segIDAssigner)
} else {
insertMsgPack, err = repackInsertDataWithPartitionKey(it.TraceCtx(), channelNames, it.partitionKeys, it.upsertMsg.InsertMsg, it.result, it.idAllocator, it.segIDAssigner)
}
if err != nil {
log.Warn("assign segmentID and repack insert data failed when insertExecute",
zap.Error(err))
it.result.Status = merr.Status(err)
return err
}
assignSegmentIDDur := tr.RecordSpan()
log.Debug("assign segmentID for insert data success when insertExecute",
zap.String("collectionName", it.req.CollectionName),
zap.Duration("assign segmentID duration", assignSegmentIDDur))
msgPack.Msgs = append(msgPack.Msgs, insertMsgPack.Msgs...)
log.Debug("Proxy Insert Execute done when upsert",
zap.String("collectionName", collectionName))
return nil
}
func (it *upsertTask) deleteExecute(ctx context.Context, msgPack *msgstream.MsgPack) (err error) {
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy deleteExecute upsert %d", it.ID()))
collID := it.upsertMsg.DeleteMsg.CollectionID
log := log.Ctx(ctx).With(
zap.Int64("collectionID", collID))
// hash primary keys to channels
channelNames, err := it.chMgr.getVChannels(collID)
if err != nil {
log.Warn("get vChannels failed when deleteExecute", zap.Error(err))
it.result.Status = merr.Status(err)
return err
}
it.upsertMsg.DeleteMsg.PrimaryKeys = it.oldIDs
it.upsertMsg.DeleteMsg.HashValues = typeutil.HashPK2Channels(it.upsertMsg.DeleteMsg.PrimaryKeys, channelNames)
// repack delete msg by dmChannel
result := make(map[uint32]msgstream.TsMsg)
collectionName := it.upsertMsg.DeleteMsg.CollectionName
collectionID := it.upsertMsg.DeleteMsg.CollectionID
partitionID := it.upsertMsg.DeleteMsg.PartitionID
partitionName := it.upsertMsg.DeleteMsg.PartitionName
proxyID := it.upsertMsg.DeleteMsg.Base.SourceID
for index, key := range it.upsertMsg.DeleteMsg.HashValues {
ts := it.upsertMsg.DeleteMsg.Timestamps[index]
_, ok := result[key]
if !ok {
msgid, err := it.idAllocator.AllocOne()
if err != nil {
return errors.Wrap(err, "failed to allocate MsgID for delete of upsert")
}
sliceRequest := &msgpb.DeleteRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Delete),
commonpbutil.WithTimeStamp(ts),
// id of upsertTask were set as ts in scheduler
// msgid of delete msg must be set
// or it will be seen as duplicated msg in mq
commonpbutil.WithMsgID(msgid),
commonpbutil.WithSourceID(proxyID),
),
CollectionID: collectionID,
PartitionID: partitionID,
CollectionName: collectionName,
PartitionName: partitionName,
PrimaryKeys: &schemapb.IDs{},
}
deleteMsg := &msgstream.DeleteMsg{
BaseMsg: msgstream.BaseMsg{
Ctx: ctx,
},
DeleteRequest: sliceRequest,
}
result[key] = deleteMsg
}
curMsg := result[key].(*msgstream.DeleteMsg)
curMsg.HashValues = append(curMsg.HashValues, it.upsertMsg.DeleteMsg.HashValues[index])
curMsg.Timestamps = append(curMsg.Timestamps, it.upsertMsg.DeleteMsg.Timestamps[index])
typeutil.AppendIDs(curMsg.PrimaryKeys, it.upsertMsg.DeleteMsg.PrimaryKeys, index)
curMsg.NumRows++
curMsg.ShardName = channelNames[key]
}
// send delete request to log broker
deleteMsgPack := &msgstream.MsgPack{
BeginTs: it.upsertMsg.DeleteMsg.BeginTs(),
EndTs: it.upsertMsg.DeleteMsg.EndTs(),
}
for _, msg := range result {
if msg != nil {
deleteMsgPack.Msgs = append(deleteMsgPack.Msgs, msg)
}
}
msgPack.Msgs = append(msgPack.Msgs, deleteMsgPack.Msgs...)
log.Debug("Proxy Upsert deleteExecute done", zap.Int64("collection_id", collID),
zap.Strings("virtual_channels", channelNames), zap.Int64("taskID", it.ID()),
zap.Duration("prepare duration", tr.ElapseSpan()))
return nil
}
func (it *upsertTask) Execute(ctx context.Context) (err error) {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-Execute")
defer sp.End()
log := log.Ctx(ctx).With(zap.String("collectionName", it.req.CollectionName))
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute upsert %d", it.ID()))
stream, err := it.chMgr.getOrCreateDmlStream(ctx, it.collectionID)
if err != nil {
return err
}
msgPack := &msgstream.MsgPack{
BeginTs: it.BeginTs(),
EndTs: it.EndTs(),
}
err = it.insertExecute(ctx, msgPack)
if err != nil {
log.Warn("Fail to insertExecute", zap.Error(err))
return err
}
err = it.deleteExecute(ctx, msgPack)
if err != nil {
log.Warn("Fail to deleteExecute", zap.Error(err))
return err
}
tr.RecordSpan()
err = stream.Produce(ctx, msgPack)
if err != nil {
it.result.Status = merr.Status(err)
return err
}
sendMsgDur := tr.RecordSpan()
metrics.ProxySendMutationReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.UpsertLabel).Observe(float64(sendMsgDur.Milliseconds()))
totalDur := tr.ElapseSpan()
log.Debug("Proxy Upsert Execute done", zap.Int64("taskID", it.ID()),
zap.Duration("total duration", totalDur))
return nil
}
func (it *upsertTask) PostExecute(ctx context.Context) error {
return nil
}

View File

@ -15,11 +15,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
type upsertTaskByStreamingService struct {
*upsertTask
}
func (ut *upsertTaskByStreamingService) Execute(ctx context.Context) error {
func (ut *upsertTask) Execute(ctx context.Context) error {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-Execute")
defer sp.End()
log := log.Ctx(ctx).With(zap.String("collectionName", ut.req.CollectionName))
@ -46,7 +42,7 @@ func (ut *upsertTaskByStreamingService) Execute(ctx context.Context) error {
return nil
}
func (ut *upsertTaskByStreamingService) packInsertMessage(ctx context.Context) ([]message.MutableMessage, error) {
func (ut *upsertTask) packInsertMessage(ctx context.Context) ([]message.MutableMessage, error) {
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy insertExecute upsert %d", ut.ID()))
defer tr.Elapse("insert execute done when insertExecute")
@ -93,7 +89,7 @@ func (ut *upsertTaskByStreamingService) packInsertMessage(ctx context.Context) (
return msgs, nil
}
func (it *upsertTaskByStreamingService) packDeleteMessage(ctx context.Context) ([]message.MutableMessage, error) {
func (it *upsertTask) packDeleteMessage(ctx context.Context) ([]message.MutableMessage, error) {
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy deleteExecute upsert %d", it.ID()))
collID := it.upsertMsg.DeleteMsg.CollectionID
it.upsertMsg.DeleteMsg.PrimaryKeys = it.oldIDs

View File

@ -2286,184 +2286,6 @@ func checkDynamicFieldData(schema *schemapb.CollectionSchema, insertMsg *msgstre
return nil
}
func SendReplicateMessagePack(ctx context.Context, replicateMsgStream msgstream.MsgStream, request interface{ GetBase() *commonpb.MsgBase }) {
if replicateMsgStream == nil || request == nil {
log.Ctx(ctx).Warn("replicate msg stream or request is nil", zap.Any("request", request))
return
}
msgBase := request.GetBase()
ts := msgBase.GetTimestamp()
if msgBase.GetReplicateInfo().GetIsReplicate() {
ts = msgBase.GetReplicateInfo().GetMsgTimestamp()
}
getBaseMsg := func(ctx context.Context, ts uint64) msgstream.BaseMsg {
return msgstream.BaseMsg{
Ctx: ctx,
HashValues: []uint32{0},
BeginTimestamp: ts,
EndTimestamp: ts,
}
}
var tsMsg msgstream.TsMsg
switch r := request.(type) {
case *milvuspb.AlterCollectionRequest:
tsMsg = &msgstream.AlterCollectionMsg{
BaseMsg: getBaseMsg(ctx, ts),
AlterCollectionRequest: r,
}
case *milvuspb.AlterCollectionFieldRequest:
tsMsg = &msgstream.AlterCollectionFieldMsg{
BaseMsg: getBaseMsg(ctx, ts),
AlterCollectionFieldRequest: r,
}
case *milvuspb.RenameCollectionRequest:
tsMsg = &msgstream.RenameCollectionMsg{
BaseMsg: getBaseMsg(ctx, ts),
RenameCollectionRequest: r,
}
case *milvuspb.CreateDatabaseRequest:
tsMsg = &msgstream.CreateDatabaseMsg{
BaseMsg: getBaseMsg(ctx, ts),
CreateDatabaseRequest: r,
}
case *milvuspb.DropDatabaseRequest:
tsMsg = &msgstream.DropDatabaseMsg{
BaseMsg: getBaseMsg(ctx, ts),
DropDatabaseRequest: r,
}
case *milvuspb.AlterDatabaseRequest:
tsMsg = &msgstream.AlterDatabaseMsg{
BaseMsg: getBaseMsg(ctx, ts),
AlterDatabaseRequest: r,
}
case *milvuspb.FlushRequest:
tsMsg = &msgstream.FlushMsg{
BaseMsg: getBaseMsg(ctx, ts),
FlushRequest: r,
}
case *milvuspb.LoadCollectionRequest:
tsMsg = &msgstream.LoadCollectionMsg{
BaseMsg: getBaseMsg(ctx, ts),
LoadCollectionRequest: r,
}
case *milvuspb.ReleaseCollectionRequest:
tsMsg = &msgstream.ReleaseCollectionMsg{
BaseMsg: getBaseMsg(ctx, ts),
ReleaseCollectionRequest: r,
}
case *milvuspb.CreateIndexRequest:
tsMsg = &msgstream.CreateIndexMsg{
BaseMsg: getBaseMsg(ctx, ts),
CreateIndexRequest: r,
}
case *milvuspb.DropIndexRequest:
tsMsg = &msgstream.DropIndexMsg{
BaseMsg: getBaseMsg(ctx, ts),
DropIndexRequest: r,
}
case *milvuspb.LoadPartitionsRequest:
tsMsg = &msgstream.LoadPartitionsMsg{
BaseMsg: getBaseMsg(ctx, ts),
LoadPartitionsRequest: r,
}
case *milvuspb.ReleasePartitionsRequest:
tsMsg = &msgstream.ReleasePartitionsMsg{
BaseMsg: getBaseMsg(ctx, ts),
ReleasePartitionsRequest: r,
}
case *milvuspb.AlterIndexRequest:
tsMsg = &msgstream.AlterIndexMsg{
BaseMsg: getBaseMsg(ctx, ts),
AlterIndexRequest: r,
}
case *milvuspb.CreateCredentialRequest:
tsMsg = &msgstream.CreateUserMsg{
BaseMsg: getBaseMsg(ctx, ts),
CreateCredentialRequest: r,
}
case *milvuspb.UpdateCredentialRequest:
tsMsg = &msgstream.UpdateUserMsg{
BaseMsg: getBaseMsg(ctx, ts),
UpdateCredentialRequest: r,
}
case *milvuspb.DeleteCredentialRequest:
tsMsg = &msgstream.DeleteUserMsg{
BaseMsg: getBaseMsg(ctx, ts),
DeleteCredentialRequest: r,
}
case *milvuspb.CreateRoleRequest:
tsMsg = &msgstream.CreateRoleMsg{
BaseMsg: getBaseMsg(ctx, ts),
CreateRoleRequest: r,
}
case *milvuspb.DropRoleRequest:
tsMsg = &msgstream.DropRoleMsg{
BaseMsg: getBaseMsg(ctx, ts),
DropRoleRequest: r,
}
case *milvuspb.OperateUserRoleRequest:
tsMsg = &msgstream.OperateUserRoleMsg{
BaseMsg: getBaseMsg(ctx, ts),
OperateUserRoleRequest: r,
}
case *milvuspb.OperatePrivilegeRequest:
tsMsg = &msgstream.OperatePrivilegeMsg{
BaseMsg: getBaseMsg(ctx, ts),
OperatePrivilegeRequest: r,
}
case *milvuspb.OperatePrivilegeV2Request:
tsMsg = &msgstream.OperatePrivilegeV2Msg{
BaseMsg: getBaseMsg(ctx, ts),
OperatePrivilegeV2Request: r,
}
case *milvuspb.CreatePrivilegeGroupRequest:
tsMsg = &msgstream.CreatePrivilegeGroupMsg{
BaseMsg: getBaseMsg(ctx, ts),
CreatePrivilegeGroupRequest: r,
}
case *milvuspb.DropPrivilegeGroupRequest:
tsMsg = &msgstream.DropPrivilegeGroupMsg{
BaseMsg: getBaseMsg(ctx, ts),
DropPrivilegeGroupRequest: r,
}
case *milvuspb.OperatePrivilegeGroupRequest:
tsMsg = &msgstream.OperatePrivilegeGroupMsg{
BaseMsg: getBaseMsg(ctx, ts),
OperatePrivilegeGroupRequest: r,
}
case *milvuspb.CreateAliasRequest:
tsMsg = &msgstream.CreateAliasMsg{
BaseMsg: getBaseMsg(ctx, ts),
CreateAliasRequest: r,
}
case *milvuspb.DropAliasRequest:
tsMsg = &msgstream.DropAliasMsg{
BaseMsg: getBaseMsg(ctx, ts),
DropAliasRequest: r,
}
case *milvuspb.AlterAliasRequest:
tsMsg = &msgstream.AlterAliasMsg{
BaseMsg: getBaseMsg(ctx, ts),
AlterAliasRequest: r,
}
default:
log.Warn("unknown request", zap.Any("request", request))
return
}
msgPack := &msgstream.MsgPack{
BeginTs: ts,
EndTs: ts,
Msgs: []msgstream.TsMsg{tsMsg},
}
msgErr := replicateMsgStream.Produce(ctx, msgPack)
// ignore the error if the msg stream failed to produce the msg,
// because it can be manually fixed in this error
if msgErr != nil {
log.Warn("send replicate msg failed", zap.Any("pack", msgPack), zap.Error(msgErr))
}
}
func GetCachedCollectionSchema(ctx context.Context, dbName string, colName string) (*schemaInfo, error) {
if globalMetaCache != nil {
return globalMetaCache.GetCollectionSchema(ctx, dbName, colName)

View File

@ -2302,55 +2302,6 @@ func Test_validateMaxCapacityPerRow(t *testing.T) {
})
}
func TestSendReplicateMessagePack(t *testing.T) {
ctx := context.Background()
mockStream := msgstream.NewMockMsgStream(t)
t.Run("empty case", func(t *testing.T) {
SendReplicateMessagePack(ctx, nil, nil)
})
t.Run("produce fail", func(t *testing.T) {
mockStream.EXPECT().Produce(mock.Anything, mock.Anything).Return(errors.New("produce error")).Once()
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateDatabaseRequest{
Base: &commonpb.MsgBase{ReplicateInfo: &commonpb.ReplicateInfo{
IsReplicate: true,
MsgTimestamp: 100,
}},
})
})
t.Run("unknown request", func(t *testing.T) {
SendReplicateMessagePack(ctx, mockStream, &milvuspb.ListDatabasesRequest{})
})
t.Run("normal case", func(t *testing.T) {
mockStream.EXPECT().Produce(mock.Anything, mock.Anything).Return(nil)
SendReplicateMessagePack(ctx, mockStream, &milvuspb.AlterCollectionRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.AlterCollectionFieldRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.RenameCollectionRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateDatabaseRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.DropDatabaseRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.FlushRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.LoadCollectionRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.ReleaseCollectionRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateIndexRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.DropIndexRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.LoadPartitionsRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.ReleasePartitionsRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateCredentialRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.DeleteCredentialRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateRoleRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.DropRoleRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.OperateUserRoleRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.OperatePrivilegeRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateAliasRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.DropAliasRequest{})
SendReplicateMessagePack(ctx, mockStream, &milvuspb.AlterAliasRequest{})
})
}
func TestAppendUserInfoForRPC(t *testing.T) {
ctx := GetContext(context.Background(), "root:123456")
ctx = AppendUserInfoForRPC(ctx)

View File

@ -48,12 +48,10 @@ import (
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/internal/util/searchutil/optimizers"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/internal/util/streamrpc"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
@ -136,8 +134,6 @@ type shardDelegator struct {
// stream delete buffer
deleteMut sync.RWMutex
deleteBuffer deletebuffer.DeleteBuffer[*deletebuffer.Item]
// dispatcherClient msgdispatcher.Client
factory msgstream.Factory
sf conc.Singleflight[struct{}]
loader segments.Loader
@ -931,7 +927,7 @@ func (sd *shardDelegator) speedupGuranteeTS(
// when 1. streaming service is disable,
// 2. consistency level is not strong,
// 3. cannot speed iterator, because current client of milvus doesn't support shard level mvcc.
if !streamingutil.IsStreamingServiceEnabled() || isIterator || cl != commonpb.ConsistencyLevel_Strong || mvccTS != 0 {
if isIterator || cl != commonpb.ConsistencyLevel_Strong || mvccTS != 0 {
return guaranteeTS
}
// use the mvcc timestamp of the wal as the guarantee timestamp to make fast strong consistency search.
@ -1145,8 +1141,7 @@ func (sd *shardDelegator) loadPartitionStats(ctx context.Context, partStatsVersi
// NewShardDelegator creates a new ShardDelegator instance with all fields initialized.
func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID UniqueID, channel string, version int64,
workerManager cluster.Manager, manager *segments.Manager, loader segments.Loader,
factory msgstream.Factory, startTs uint64, queryHook optimizers.QueryHook, chunkManager storage.ChunkManager,
workerManager cluster.Manager, manager *segments.Manager, loader segments.Loader, startTs uint64, queryHook optimizers.QueryHook, chunkManager storage.ChunkManager,
queryView *channelQueryView,
) (ShardDelegator, error) {
log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID),
@ -1184,7 +1179,6 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
pkOracle: pkoracle.NewPkOracle(),
latestTsafe: atomic.NewUint64(startTs),
loader: loader,
factory: factory,
queryHook: queryHook,
chunkManager: chunkManager,
partitionStats: make(map[UniqueID]*storage.PartitionStatsSnapshot),

View File

@ -19,7 +19,6 @@ package delegator
import (
"context"
"fmt"
"math/rand"
"runtime"
"time"
@ -40,8 +39,6 @@ import (
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
mqcommon "github.com/milvus-io/milvus/pkg/v2/mq/common"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
@ -771,120 +768,6 @@ func (sd *shardDelegator) loadStreamDelete(ctx context.Context,
return nil
}
func (sd *shardDelegator) createStreamFromMsgStream(ctx context.Context, position *msgpb.MsgPosition) (ch <-chan *msgstream.MsgPack, closer func(), err error) {
stream, err := sd.factory.NewTtMsgStream(ctx)
if err != nil {
return nil, nil, err
}
defer stream.Close()
vchannelName := position.ChannelName
pChannelName := funcutil.ToPhysicalChannel(vchannelName)
position.ChannelName = pChannelName
ts, _ := tsoutil.ParseTS(position.Timestamp)
// Random the subname in case we trying to load same delta at the same time
subName := fmt.Sprintf("querynode-delta-loader-%d-%d-%d", paramtable.GetNodeID(), sd.collectionID, rand.Int())
log.Info("from dml check point load delete", zap.Any("position", position), zap.String("vChannel", vchannelName), zap.String("subName", subName), zap.Time("positionTs", ts))
err = stream.AsConsumer(context.TODO(), []string{pChannelName}, subName, mqcommon.SubscriptionPositionUnknown)
if err != nil {
return nil, stream.Close, err
}
err = stream.Seek(context.TODO(), []*msgpb.MsgPosition{position}, false)
if err != nil {
return nil, stream.Close, err
}
dispatcher := msgstream.NewSimpleMsgDispatcher(stream, func(pm msgstream.ConsumeMsg) bool {
if pm.GetType() != commonpb.MsgType_Delete || pm.GetVChannel() != vchannelName {
return false
}
return true
})
return dispatcher.Chan(), dispatcher.Close, nil
}
// Only used in test.
func (sd *shardDelegator) readDeleteFromMsgstream(ctx context.Context, position *msgpb.MsgPosition, safeTs uint64, candidate *pkoracle.BloomFilterSet) (*storage.DeleteData, error) {
log := sd.getLogger(ctx).With(
zap.String("channel", position.ChannelName),
zap.Int64("segmentID", candidate.ID()),
)
pChannelName := funcutil.ToPhysicalChannel(position.ChannelName)
var ch <-chan *msgstream.MsgPack
var closer func()
var err error
ch, closer, err = sd.createStreamFromMsgStream(ctx, position)
if closer != nil {
defer closer()
}
if err != nil {
return nil, err
}
start := time.Now()
result := &storage.DeleteData{}
hasMore := true
for hasMore {
select {
case <-ctx.Done():
log.Debug("read delta msg from seek position done", zap.Error(ctx.Err()))
return nil, ctx.Err()
case msgPack, ok := <-ch:
if !ok {
err = fmt.Errorf("stream channel closed, pChannelName=%v, msgID=%v", pChannelName, position.GetMsgID())
log.Warn("fail to read delta msg",
zap.String("pChannelName", pChannelName),
zap.Binary("msgID", position.GetMsgID()),
zap.Error(err),
)
return nil, err
}
if msgPack == nil {
continue
}
for _, tsMsg := range msgPack.Msgs {
if tsMsg.Type() == commonpb.MsgType_Delete {
dmsg := tsMsg.(*msgstream.DeleteMsg)
if dmsg.CollectionID != sd.collectionID || (dmsg.GetPartitionID() != common.AllPartitionsID && dmsg.GetPartitionID() != candidate.Partition()) {
continue
}
pks := storage.ParseIDs2PrimaryKeys(dmsg.GetPrimaryKeys())
batchSize := paramtable.Get().CommonCfg.BloomFilterApplyBatchSize.GetAsInt()
for idx := 0; idx < len(pks); idx += batchSize {
endIdx := idx + batchSize
if endIdx > len(pks) {
endIdx = len(pks)
}
lc := storage.NewBatchLocationsCache(pks[idx:endIdx])
hits := candidate.BatchPkExist(lc)
for i, hit := range hits {
if hit {
result.Pks = append(result.Pks, pks[idx+i])
result.Tss = append(result.Tss, dmsg.Timestamps[idx+i])
}
}
}
}
}
// reach safe ts
if safeTs <= msgPack.EndPositions[0].GetTimestamp() {
hasMore = false
}
}
}
log.Info("successfully read delete from stream ", zap.Duration("time spent", time.Since(start)))
return result, nil
}
// ReleaseSegments releases segments local or remotely depending on the target node.
func (sd *shardDelegator) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest, force bool) error {
log := sd.getLogger(ctx)

View File

@ -192,11 +192,7 @@ func (s *DelegatorDataSuite) genCollectionWithFunction() {
}},
}, nil, &querypb.LoadMetaInfo{SchemaVersion: tsoutil.ComposeTSByTime(time.Now(), 0)})
delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, &msgstream.MockMqFactory{
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.NoError(err)
s.delegator = delegator.(*shardDelegator)
}
@ -214,11 +210,7 @@ func (s *DelegatorDataSuite) SetupTest() {
s.rootPath = s.Suite.T().Name()
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath)
s.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(context.Background())
delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, &msgstream.MockMqFactory{
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.Require().NoError(err)
sd, ok := delegator.(*shardDelegator)
s.Require().True(ok)
@ -806,11 +798,7 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
s.workerManager,
s.manager,
s.loader,
&msgstream.MockMqFactory{
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, nil, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
10000, nil, nil, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.NoError(err)
growing0 := segments.NewMockSegment(s.T())
@ -1524,39 +1512,6 @@ func (s *DelegatorDataSuite) TestLevel0Deletions() {
s.Empty(pks)
}
func (s *DelegatorDataSuite) TestReadDeleteFromMsgstream() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.mq.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
s.mq.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil)
s.mq.EXPECT().Close()
ch := make(chan *msgstream.ConsumeMsgPack, 10)
s.mq.EXPECT().Chan().Return(ch)
oracle := pkoracle.NewBloomFilterSet(1, 1, commonpb.SegmentState_Sealed)
oracle.UpdateBloomFilter([]storage.PrimaryKey{storage.NewInt64PrimaryKey(1), storage.NewInt64PrimaryKey(2)})
baseMsg := &commonpb.MsgBase{MsgType: commonpb.MsgType_Delete}
datas := []*msgstream.MsgPack{
{EndTs: 10, EndPositions: []*msgpb.MsgPosition{{Timestamp: 10}}, Msgs: []msgstream.TsMsg{
&msgstream.DeleteMsg{DeleteRequest: &msgpb.DeleteRequest{Base: baseMsg, CollectionID: s.collectionID, PartitionID: 1, PrimaryKeys: storage.ParseInt64s2IDs(1), Timestamps: []uint64{1}}},
&msgstream.DeleteMsg{DeleteRequest: &msgpb.DeleteRequest{Base: baseMsg, CollectionID: s.collectionID, PartitionID: -1, PrimaryKeys: storage.ParseInt64s2IDs(2), Timestamps: []uint64{5}}},
// invalid msg because partition wrong
&msgstream.DeleteMsg{DeleteRequest: &msgpb.DeleteRequest{Base: baseMsg, CollectionID: s.collectionID, PartitionID: 2, PrimaryKeys: storage.ParseInt64s2IDs(1), Timestamps: []uint64{10}}},
}},
}
for _, data := range datas {
ch <- msgstream.BuildConsumeMsgPack(data)
}
result, err := s.delegator.readDeleteFromMsgstream(ctx, &msgpb.MsgPosition{Timestamp: 0}, 10, oracle)
s.NoError(err)
s.Equal(2, len(result.Pks))
}
func (s *DelegatorDataSuite) TestDelegatorData_ExcludeSegments() {
s.delegator.AddExcludedSegments(map[int64]uint64{
1: 3,

View File

@ -19,6 +19,7 @@ package delegator
import (
"context"
"io"
"os"
"testing"
"time"
@ -33,6 +34,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/querynodev2/cluster"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/storage"
@ -51,6 +53,12 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func TestMain(m *testing.M) {
streaming.SetupNoopWALForTest()
os.Exit(m.Run())
}
type DelegatorSuite struct {
suite.Suite
@ -164,11 +172,7 @@ func (s *DelegatorSuite) SetupTest() {
var err error
// s.delegator, err = NewShardDelegator(s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.tsafeManager, s.loader)
s.delegator, err = NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, &msgstream.MockMqFactory{
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.delegator, err = NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.Require().NoError(err)
}
@ -203,11 +207,7 @@ func (s *DelegatorSuite) TestCreateDelegatorWithFunction() {
}},
}, nil, &querypb.LoadMetaInfo{SchemaVersion: tsoutil.ComposeTSByTime(time.Now(), 0)})
_, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, manager, s.loader, &msgstream.MockMqFactory{
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
_, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.Error(err)
})
@ -246,11 +246,7 @@ func (s *DelegatorSuite) TestCreateDelegatorWithFunction() {
}},
}, nil, &querypb.LoadMetaInfo{SchemaVersion: tsoutil.ComposeTSByTime(time.Now(), 0)})
_, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, manager, s.loader, &msgstream.MockMqFactory{
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
_, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.NoError(err)
})
}
@ -1410,11 +1406,7 @@ func (s *DelegatorSuite) TestUpdateSchema() {
func (s *DelegatorSuite) ResetDelegator() {
var err error
s.delegator.Close()
s.delegator, err = NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, &msgstream.MockMqFactory{
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.delegator, err = NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.Require().NoError(err)
}

View File

@ -151,11 +151,7 @@ func (s *StreamingForwardSuite) SetupTest() {
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath)
s.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(context.Background())
delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, &msgstream.MockMqFactory{
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.Require().NoError(err)
sd, ok := delegator.(*shardDelegator)
@ -394,11 +390,7 @@ func (s *GrowingMergeL0Suite) SetupTest() {
chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath)
s.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(context.Background())
delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, &msgstream.MockMqFactory{
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.loader, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.Require().NoError(err)
sd, ok := delegator.(*shardDelegator)

View File

@ -65,7 +65,6 @@ import (
"github.com/milvus-io/milvus/internal/util/searchutil/scheduler"
"github.com/milvus-io/milvus/internal/util/segcore"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/internal/util/streamingutil/util"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/config"
@ -528,11 +527,7 @@ func (node *QueryNode) Init() error {
node.manager = segments.NewManager()
node.loader = segments.NewLoader(node.ctx, node.manager, node.chunkManager)
node.manager.SetLoader(node.loader)
if streamingutil.IsStreamingServiceEnabled() {
node.dispClient = msgdispatcher.NewClientWithIncludeSkipWhenSplit(streaming.NewDelegatorMsgstreamFactory(), typeutil.QueryNodeRole, node.GetNodeID())
} else {
node.dispClient = msgdispatcher.NewClient(node.factory, typeutil.QueryNodeRole, node.GetNodeID())
}
node.dispClient = msgdispatcher.NewClientWithIncludeSkipWhenSplit(streaming.NewDelegatorMsgstreamFactory(), typeutil.QueryNodeRole, node.GetNodeID())
// init pipeline manager
node.pipelineManager = pipeline.NewManager(node.manager, node.dispClient, node.delegators)

View File

@ -267,7 +267,6 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
node.clusterManager,
node.manager,
node.loader,
node.factory,
channel.GetSeekPosition().GetTimestamp(),
node.queryHook,
node.chunkManager,

View File

@ -49,12 +49,12 @@ import (
"github.com/milvus-io/milvus/internal/util/streamrpc"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
"github.com/milvus-io/milvus/pkg/v2/util/conc"
"github.com/milvus-io/milvus/pkg/v2/util/etcd"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
@ -68,7 +68,6 @@ import (
type ServiceSuite struct {
suite.Suite
// Data
msgChan chan *msgstream.ConsumeMsgPack
collectionID int64
collectionName string
schema *schemapb.CollectionSchema
@ -92,8 +91,7 @@ type ServiceSuite struct {
chunkManagerFactory *storage.ChunkManagerFactory
// Mock
factory *dependency.MockFactory
msgStream *msgstream.MockMsgStream
factory *dependency.MockFactory
}
func (suite *ServiceSuite) SetupSuite() {
@ -129,7 +127,6 @@ func (suite *ServiceSuite) SetupTest() {
ctx := context.Background()
// init mock
suite.factory = dependency.NewMockFactory(suite.T())
suite.msgStream = msgstream.NewMockMsgStream(suite.T())
// TODO:: cpp chunk manager not support local chunk manager
paramtable.Get().Save(paramtable.Get().LocalStorageCfg.Path.Key, suite.T().TempDir())
// suite.chunkManagerFactory = storage.NewChunkManagerFactory("local", storage.RootPath("/tmp/milvus-test"))
@ -315,13 +312,6 @@ func (suite *ServiceSuite) TestWatchDmChannelsInt64() {
IndexInfoList: mock_segcore.GenTestIndexInfoList(suite.collectionID, schema),
}
// mocks
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil).Maybe()
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil).Maybe()
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()
suite.msgStream.EXPECT().Chan().Return(suite.msgChan).Maybe()
suite.msgStream.EXPECT().Close().Maybe()
// watchDmChannels
status, err := suite.node.WatchDmChannels(ctx, req)
suite.NoError(err)
@ -367,13 +357,6 @@ func (suite *ServiceSuite) TestWatchDmChannelsVarchar() {
IndexInfoList: mock_segcore.GenTestIndexInfoList(suite.collectionID, schema),
}
// mocks
suite.factory.EXPECT().NewTtMsgStream(mock.Anything).Return(suite.msgStream, nil).Maybe()
suite.msgStream.EXPECT().AsConsumer(mock.Anything, []string{suite.pchannel}, mock.Anything, mock.Anything).Return(nil).Maybe()
suite.msgStream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()
suite.msgStream.EXPECT().Chan().Return(suite.msgChan).Maybe()
suite.msgStream.EXPECT().Close().Maybe()
// watchDmChannels
status, err := suite.node.WatchDmChannels(ctx, req)
suite.NoError(err)
@ -2437,6 +2420,13 @@ func TestQueryNodeService(t *testing.T) {
local.EXPECT().GetLatestMVCCTimestampIfLocal(mock.Anything, mock.Anything).Return(0, nil).Maybe()
local.EXPECT().GetMetricsIfLocal(mock.Anything).Return(&types.StreamingNodeMetrics{}, nil).Maybe()
wal.EXPECT().Local().Return(local).Maybe()
wal.EXPECT().WALName().Return(rmq.WALName).Maybe()
scanner := mock_streaming.NewMockScanner(t)
scanner.EXPECT().Done().Return(make(chan struct{})).Maybe()
scanner.EXPECT().Error().Return(nil).Maybe()
scanner.EXPECT().Close().Return().Maybe()
wal.EXPECT().Read(mock.Anything, mock.Anything).Return(scanner).Maybe()
streaming.SetWALForTest(wal)
defer streaming.RecoverWALForTest()

View File

@ -1,98 +0,0 @@
package rootcoord
import (
"context"
"github.com/milvus-io/milvus/internal/util/streamingutil/util"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message/adaptor"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
var _ task = (*broadcastTask)(nil)
// BroadcastTask is used to implement the broadcast operation based on the msgstream
// by using the streaming service interface.
// msgstream will be deprecated since 2.6.0 with streaming service, so those code will be removed in the future version.
type broadcastTask struct {
baseTask
msgs []message.MutableMessage // The message wait for broadcast
walName string
resultFuture *syncutil.Future[types.AppendResponses]
}
func (b *broadcastTask) Execute(ctx context.Context) error {
result := types.NewAppendResponseN(len(b.msgs))
defer func() {
b.resultFuture.Set(result)
}()
for idx, msg := range b.msgs {
tsMsg, err := adaptor.NewMsgPackFromMutableMessageV1(msg)
tsMsg.SetTs(b.ts) // overwrite the ts.
if err != nil {
result.FillResponseAtIdx(types.AppendResponse{Error: err}, idx)
return err
}
pchannel := funcutil.ToPhysicalChannel(msg.VChannel())
msgID, err := b.core.chanTimeTick.broadcastMarkDmlChannels([]string{pchannel}, &msgstream.MsgPack{
BeginTs: b.ts,
EndTs: b.ts,
Msgs: []msgstream.TsMsg{tsMsg},
})
if err != nil {
result.FillResponseAtIdx(types.AppendResponse{Error: err}, idx)
continue
}
result.FillResponseAtIdx(types.AppendResponse{
AppendResult: &types.AppendResult{
MessageID: adaptor.MustGetMessageIDFromMQWrapperIDBytes(b.walName, msgID[pchannel]),
TimeTick: b.ts,
},
}, idx)
}
return result.UnwrapFirstError()
}
func newMsgStreamAppendOperator(c *Core) *msgstreamAppendOperator {
return &msgstreamAppendOperator{
core: c,
walName: util.MustSelectWALName(),
}
}
// msgstreamAppendOperator the code of streamingcoord to make broadcast available on the legacy msgstream.
// Because msgstream is bound to the rootcoord task, so we transfer each broadcast operation into a ddl task.
// to make sure the timetick rule.
// The Msgstream will be deprecated since 2.6.0, so we make a single module to hold it.
type msgstreamAppendOperator struct {
core *Core
walName string
}
// AppendMessages implements the AppendOperator interface for broadcaster service at streaming service.
func (m *msgstreamAppendOperator) AppendMessages(ctx context.Context, msgs ...message.MutableMessage) types.AppendResponses {
t := &broadcastTask{
baseTask: newBaseTask(ctx, m.core),
msgs: msgs,
walName: m.walName,
resultFuture: syncutil.NewFuture[types.AppendResponses](),
}
if err := m.core.scheduler.AddTask(t); err != nil {
resp := types.NewAppendResponseN(len(msgs))
resp.FillAllError(err)
return resp
}
result, err := t.resultFuture.GetWithContext(ctx)
if err != nil {
resp := types.NewAppendResponseN(len(msgs))
resp.FillAllError(err)
return resp
}
return result
}

View File

@ -44,7 +44,6 @@ import (
kvmetastore "github.com/milvus-io/milvus/internal/metastore/kv/rootcoord"
"github.com/milvus-io/milvus/internal/metastore/model"
streamingcoord "github.com/milvus-io/milvus/internal/streamingcoord/server"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
tso2 "github.com/milvus-io/milvus/internal/tso"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
@ -694,11 +693,6 @@ func (c *Core) startInternal() error {
c.UpdateStateCode(commonpb.StateCode_Healthy)
sessionutil.SaveServerInfo(typeutil.MixCoordRole, c.session.GetServerID())
log.Info("rootcoord startup successfully")
// regster the core as a appendoperator for broadcast service.
// TODO: should be removed at 2.6.0.
// Add the wal accesser to the broadcaster registry for making broadcast operation.
registry.Register(registry.AppendOperatorTypeMsgstream, newMsgStreamAppendOperator(c))
return nil
}

View File

@ -12,7 +12,6 @@ import (
"github.com/milvus-io/milvus/internal/streamingcoord/client/assignment"
"github.com/milvus-io/milvus/internal/streamingcoord/client/broadcast"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/balancer/picker"
streamingserviceinterceptor "github.com/milvus-io/milvus/internal/util/streamingutil/service/interceptor"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/lazygrpc"
@ -76,11 +75,8 @@ func NewClient(etcdCli *clientv3.Client) Client {
dialOptions...,
)
})
var assignmentServiceImpl *assignment.AssignmentServiceImpl
if streamingutil.IsStreamingServiceEnabled() {
assignmentService := lazygrpc.WithServiceCreator(conn, streamingpb.NewStreamingCoordAssignmentServiceClient)
assignmentServiceImpl = assignment.NewAssignmentService(assignmentService)
}
assignmentService := lazygrpc.WithServiceCreator(conn, streamingpb.NewStreamingCoordAssignmentServiceClient)
assignmentServiceImpl := assignment.NewAssignmentService(assignmentService)
broadcastService := lazygrpc.WithServiceCreator(conn, streamingpb.NewStreamingCoordBroadcastServiceClient)
return &clientImpl{
conn: conn,

View File

@ -1,14 +0,0 @@
package broadcaster
import (
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/util/streamingutil"
)
// NewAppendOperator creates an append operator to handle the incoming messages for broadcaster.
func NewAppendOperator() AppendOperator {
if streamingutil.IsStreamingServiceEnabled() {
return streaming.WAL()
}
return nil
}

View File

@ -3,7 +3,6 @@ package broadcaster
import (
"context"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
)
@ -19,4 +18,9 @@ type Broadcaster interface {
Close()
}
type AppendOperator = registry.AppendOperator
// AppendOperator is used to append messages, there's only two implement of this interface:
// 1. streaming.WAL()
// 2. old msgstream interface [deprecated]
type AppendOperator interface {
AppendMessages(ctx context.Context, msgs ...message.MutableMessage) types.AppendResponses
}

View File

@ -22,7 +22,6 @@ import (
func RecoverBroadcaster(
ctx context.Context,
appendOperator *syncutil.Future[AppendOperator],
) (Broadcaster, error) {
tasks, err := resource.Resource().StreamingCatalog().ListBroadcastTask(ctx)
if err != nil {
@ -38,7 +37,6 @@ func RecoverBroadcaster(
backoffChan: make(chan *pendingBroadcastTask),
pendingChan: make(chan *pendingBroadcastTask),
workerChan: make(chan *pendingBroadcastTask),
appendOperator: appendOperator,
}
go b.execute()
return b, nil
@ -54,7 +52,6 @@ type broadcasterImpl struct {
pendingChan chan *pendingBroadcastTask
backoffChan chan *pendingBroadcastTask
workerChan chan *pendingBroadcastTask
appendOperator *syncutil.Future[AppendOperator] // TODO: we can remove those lazy future in 2.6.0, by remove the msgstream broadcaster.
}
// Broadcast broadcasts the message to all channels.
@ -140,14 +137,6 @@ func (b *broadcasterImpl) execute() {
b.Logger().Info("broadcaster execute exit")
}()
// Wait for appendOperator ready
appendOperator, err := b.appendOperator.GetWithContext(b.backgroundTaskNotifier.Context())
if err != nil {
b.Logger().Info("broadcaster is closed before appendOperator ready")
return
}
b.Logger().Info("broadcaster appendOperator ready, begin to start workers and dispatch")
// Start n workers to handle the broadcast task.
wg := sync.WaitGroup{}
for i := 0; i < workers; i++ {
@ -156,7 +145,7 @@ func (b *broadcasterImpl) execute() {
wg.Add(1)
go func() {
defer wg.Done()
b.worker(i, appendOperator)
b.worker(i)
}()
}
defer wg.Wait()
@ -205,7 +194,7 @@ func (b *broadcasterImpl) dispatch() {
}
}
func (b *broadcasterImpl) worker(no int, appendOperator AppendOperator) {
func (b *broadcasterImpl) worker(no int) {
logger := b.Logger().With(zap.Int("workerNo", no))
defer func() {
logger.Info("broadcaster worker exit")
@ -216,7 +205,7 @@ func (b *broadcasterImpl) worker(no int, appendOperator AppendOperator) {
case <-b.backgroundTaskNotifier.Context().Done():
return
case task := <-b.workerChan:
if err := task.Execute(b.backgroundTaskNotifier.Context(), appendOperator); err != nil {
if err := task.Execute(b.backgroundTaskNotifier.Context()); err != nil {
// If the task is not done, repush it into pendings and retry infinitely.
select {
case <-b.backgroundTaskNotifier.Context().Done():

View File

@ -12,8 +12,9 @@ import (
"go.uber.org/atomic"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/mocks/distributed/mock_streaming"
"github.com/milvus-io/milvus/internal/mocks/mock_metastore"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_broadcaster"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
internaltypes "github.com/milvus-io/milvus/internal/types"
@ -76,8 +77,8 @@ func TestBroadcaster(t *testing.T) {
resource.InitForTest(resource.OptStreamingCatalog(meta), resource.OptMixCoordClient(f))
fbc := syncutil.NewFuture[Broadcaster]()
operator, appended := createOpeartor(t, fbc)
bc, err := RecoverBroadcaster(context.Background(), operator)
appended := createOpeartor(t, fbc)
bc, err := RecoverBroadcaster(context.Background())
fbc.Set(bc)
assert.NoError(t, err)
assert.NotNil(t, bc)
@ -135,10 +136,10 @@ func ack(broadcaster Broadcaster, broadcastID uint64, vchannel string) {
}
}
func createOpeartor(t *testing.T, broadcaster *syncutil.Future[Broadcaster]) (*syncutil.Future[AppendOperator], *atomic.Int64) {
func createOpeartor(t *testing.T, broadcaster *syncutil.Future[Broadcaster]) *atomic.Int64 {
id := atomic.NewInt64(1)
appended := atomic.NewInt64(0)
operator := mock_broadcaster.NewMockAppendOperator(t)
operator := mock_streaming.NewMockWALAccesser(t)
f := func(ctx context.Context, msgs ...message.MutableMessage) types.AppendResponses {
resps := types.AppendResponses{
Responses: make([]types.AppendResponse, len(msgs)),
@ -174,9 +175,8 @@ func createOpeartor(t *testing.T, broadcaster *syncutil.Future[Broadcaster]) (*s
operator.EXPECT().AppendMessages(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(f)
operator.EXPECT().AppendMessages(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(f)
fOperator := syncutil.NewFuture[AppendOperator]()
fOperator.Set(operator)
return fOperator, appended
streaming.SetWALForTest(operator)
return appended
}
func createNewBroadcastMsg(vchannels []string, rks ...message.ResourceKey) message.BroadcastMutableMessage {

View File

@ -1,42 +0,0 @@
package registry
import (
"context"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
type AppendOperatorType int
const (
AppendOperatorTypeMsgstream AppendOperatorType = iota + 1
AppendOperatorTypeStreaming
)
var localRegistry = make(map[AppendOperatorType]*syncutil.Future[AppendOperator])
// AppendOperator is used to append messages, there's only two implement of this interface:
// 1. streaming.WAL()
// 2. old msgstream interface
type AppendOperator interface {
AppendMessages(ctx context.Context, msgs ...message.MutableMessage) types.AppendResponses
}
func init() {
localRegistry[AppendOperatorTypeMsgstream] = syncutil.NewFuture[AppendOperator]()
localRegistry[AppendOperatorTypeStreaming] = syncutil.NewFuture[AppendOperator]()
}
func Register(typ AppendOperatorType, op AppendOperator) {
localRegistry[typ].Set(op)
}
func GetAppendOperator() *syncutil.Future[AppendOperator] {
if streamingutil.IsStreamingServiceEnabled() {
return localRegistry[AppendOperatorTypeStreaming]
}
return localRegistry[AppendOperatorTypeMsgstream]
}

View File

@ -3,12 +3,7 @@
package registry
import "github.com/milvus-io/milvus/pkg/v2/util/syncutil"
func ResetRegistration() {
localRegistry = make(map[AppendOperatorType]*syncutil.Future[AppendOperator])
localRegistry[AppendOperatorTypeMsgstream] = syncutil.NewFuture[AppendOperator]()
localRegistry[AppendOperatorTypeStreaming] = syncutil.NewFuture[AppendOperator]()
resetMessageAckCallbacks()
resetMessageCheckCallbacks()
}

View File

@ -7,6 +7,7 @@ import (
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
@ -49,7 +50,7 @@ type pendingBroadcastTask struct {
// Execute reexecute the task, return nil if the task is done, otherwise not done.
// Execute can be repeated called until the task is done.
// Same semantics as the `Poll` operation in eventloop.
func (b *pendingBroadcastTask) Execute(ctx context.Context, operator AppendOperator) error {
func (b *pendingBroadcastTask) Execute(ctx context.Context) error {
if err := b.broadcastTask.InitializeRecovery(ctx); err != nil {
b.Logger().Warn("broadcast task initialize recovery failed", zap.Error(err))
b.UpdateInstantWithNextBackOff()
@ -58,7 +59,7 @@ func (b *pendingBroadcastTask) Execute(ctx context.Context, operator AppendOpera
if len(b.pendingMessages) > 0 {
b.Logger().Debug("broadcast task is polling to make sent...", zap.Int("pendingMessages", len(b.pendingMessages)))
resps := operator.AppendMessages(ctx, b.pendingMessages...)
resps := streaming.WAL().AppendMessages(ctx, b.pendingMessages...)
newPendings := make([]message.MutableMessage, 0)
for idx, resp := range resps.Responses {
if resp.Error != nil {

View File

@ -9,7 +9,6 @@ import (
"github.com/milvus-io/milvus/internal/streamingnode/client/manager"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/idalloc"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
@ -55,10 +54,8 @@ func Init(opts ...optResourceInit) {
assertNotNil(newR.MixCoordClient())
assertNotNil(newR.ETCD())
assertNotNil(newR.StreamingCatalog())
if streamingutil.IsStreamingServiceEnabled() {
newR.streamingNodeManagerClient = manager.NewManagerClient(newR.etcdClient)
assertNotNil(newR.StreamingNodeManagerClient())
}
newR.streamingNodeManagerClient = manager.NewManagerClient(newR.etcdClient)
assertNotNil(newR.StreamingNodeManagerClient())
r = newR
}

View File

@ -10,11 +10,9 @@ import (
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
_ "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/policy" // register the balancer policy
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
"github.com/milvus-io/milvus/internal/streamingcoord/server/service"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/internal/util/streamingutil/util"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
@ -53,27 +51,25 @@ func (s *Server) Start(ctx context.Context) (err error) {
// initBasicComponent initialize all underlying dependency for streamingcoord.
func (s *Server) initBasicComponent(ctx context.Context) (err error) {
futures := make([]*conc.Future[struct{}], 0)
if streamingutil.IsStreamingServiceEnabled() {
futures = append(futures, conc.Go(func() (struct{}, error) {
s.logger.Info("start recovery balancer...")
// Read new incoming topics from configuration, and register it into balancer.
newIncomingTopics := util.GetAllTopicsFromConfiguration()
balancer, err := balancer.RecoverBalancer(ctx, newIncomingTopics.Collect()...)
if err != nil {
s.logger.Warn("recover balancer failed", zap.Error(err))
return struct{}{}, err
}
s.balancer.Set(balancer)
snmanager.StaticStreamingNodeManager.SetBalancerReady(balancer)
s.logger.Info("recover balancer done")
return struct{}{}, nil
}))
}
futures = append(futures, conc.Go(func() (struct{}, error) {
s.logger.Info("start recovery balancer...")
// Read new incoming topics from configuration, and register it into balancer.
newIncomingTopics := util.GetAllTopicsFromConfiguration()
balancer, err := balancer.RecoverBalancer(ctx, newIncomingTopics.Collect()...)
if err != nil {
s.logger.Warn("recover balancer failed", zap.Error(err))
return struct{}{}, err
}
s.balancer.Set(balancer)
snmanager.StaticStreamingNodeManager.SetBalancerReady(balancer)
s.logger.Info("recover balancer done")
return struct{}{}, nil
}))
// The broadcaster of msgstream is implemented on current streamingcoord to reduce the development complexity.
// So we need to recover it.
futures = append(futures, conc.Go(func() (struct{}, error) {
s.logger.Info("start recovery broadcaster...")
broadcaster, err := broadcaster.RecoverBroadcaster(ctx, registry.GetAppendOperator())
broadcaster, err := broadcaster.RecoverBroadcaster(ctx)
if err != nil {
s.logger.Warn("recover broadcaster failed", zap.Error(err))
return struct{}{}, err
@ -87,9 +83,7 @@ func (s *Server) initBasicComponent(ctx context.Context) (err error) {
// RegisterGRPCService register all grpc service to grpc server.
func (s *Server) RegisterGRPCService(grpcServer *grpc.Server) {
if streamingutil.IsStreamingServiceEnabled() {
streamingpb.RegisterStreamingCoordAssignmentServiceServer(grpcServer, s.assignmentService)
}
streamingpb.RegisterStreamingCoordAssignmentServiceServer(grpcServer, s.assignmentService)
streamingpb.RegisterStreamingCoordBroadcastServiceServer(grpcServer, s.broadcastService)
}

View File

@ -84,15 +84,6 @@ type DataNodeComponent interface {
// SetEtcdClient set etcd client for DataNode
SetEtcdClient(etcdClient *clientv3.Client)
// SetMixCoordClient set SetMixCoordClient for DataNode
// `mixCoord` is a client of root coordinator.
//
// Return a generic error in status:
// If the mixCoord is nil or the mixCoord has been set before.
// Return nil in status:
// The mixCoord is not nil.
SetMixCoordClient(mixCoord MixCoordClient) error
}
// DataCoordClient is the client interface for datacoord server
@ -196,9 +187,6 @@ type ProxyComponent interface {
SetAddress(address string)
GetAddress() string
// SetEtcdClient set EtcdClient for Proxy
// `etcdClient` is a client of etcd
SetEtcdClient(etcdClient *clientv3.Client)
// SetMixCoordClient set MixCoord for Proxy
// `mixCoord` is a client of mix coordinator.