diff --git a/internal/msgstream/rmqms/rmq_msgstream.go b/internal/msgstream/rmqms/rmq_msgstream.go index eaa35f6f81..f5db14bcbf 100644 --- a/internal/msgstream/rmqms/rmq_msgstream.go +++ b/internal/msgstream/rmqms/rmq_msgstream.go @@ -2,14 +2,13 @@ package rmqms import ( "context" + "errors" "log" "path/filepath" "reflect" "strconv" "sync" - "errors" - "github.com/gogo/protobuf/proto" "github.com/zilliztech/milvus-distributed/internal/msgstream/util" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" @@ -329,6 +328,7 @@ func (ms *RmqTtMsgStream) AsConsumer(channels []string, } consumer.MsgNum = make(chan int, ms.rmqBufSize) ms.consumers = append(ms.consumers, *consumer) + ms.consumerChannels = append(ms.consumerChannels, consumer.ChannelName) ms.consumerReflects = append(ms.consumerReflects, reflect.SelectCase{ Dir: reflect.SelectRecv, Chan: reflect.ValueOf(consumer.MsgNum), @@ -376,6 +376,9 @@ func (ms *RmqTtMsgStream) bufMsgPackToChannel() { msgPositions := make([]*msgstream.MsgPosition, 0) ms.unsolvedMutex.Lock() for consumer, msgs := range ms.unsolvedBuf { + if len(msgs) == 0 { + continue + } tempBuffer := make([]TsMsg, 0) var timeTickMsg TsMsg for _, v := range msgs { @@ -479,6 +482,10 @@ func (ms *RmqTtMsgStream) Seek(mp *msgstream.MsgPosition) error { for index, channel := range ms.consumerChannels { if filepath.Base(channel) == filepath.Base(mp.ChannelName) { consumer = ms.consumers[index] + if len(mp.MsgID) == 0 { + msgID = -1 + break + } seekMsgID, err := strconv.ParseInt(mp.MsgID, 10, 64) if err != nil { return err @@ -491,51 +498,94 @@ func (ms *RmqTtMsgStream) Seek(mp *msgstream.MsgPosition) error { if err != nil { return err } + if msgID == -1 { + return nil + } ms.unsolvedMutex.Lock() ms.unsolvedBuf[consumer] = make([]TsMsg, 0) + // When rmq seek is called, msgNum can't be used before current msgs all consumed, because + // new msgNum is not generated. So just try to consume msgs for { - select { - case <-ms.ctx.Done(): - return nil - case num, ok := <-consumer.MsgNum: - if !ok { - return errors.New("consumer closed") - } - rmqMsg, err := rocksmq.Rmq.Consume(consumer.GroupName, consumer.ChannelName, num) + rmqMsg, err := rocksmq.Rmq.Consume(consumer.GroupName, consumer.ChannelName, 1) + if err != nil { + log.Printf("Failed to consume message in rocksmq, error = %v", err) + } + if len(rmqMsg) == 0 { + break + } else { + headerMsg := commonpb.MsgHeader{} + err := proto.Unmarshal(rmqMsg[0].Payload, &headerMsg) if err != nil { - log.Printf("Failed to consume message in rocksmq, error = %v", err) + log.Printf("Failed to unmarshal message header, error = %v", err) + return err + } + tsMsg, err := ms.unmarshal.Unmarshal(rmqMsg[0].Payload, headerMsg.Base.MsgType) + if err != nil { + log.Printf("Failed to unmarshal tsMsg, error = %v", err) + return err + } + + if headerMsg.Base.MsgType == commonpb.MsgType_kTimeTick { + if tsMsg.BeginTs() >= mp.Timestamp { + ms.unsolvedMutex.Unlock() + return nil + } continue } - - for j := 0; j < len(rmqMsg); j++ { - headerMsg := commonpb.MsgHeader{} - err := proto.Unmarshal(rmqMsg[j].Payload, &headerMsg) - if err != nil { - log.Printf("Failed to unmarshal message header, error = %v", err) - } - tsMsg, err := ms.unmarshal.Unmarshal(rmqMsg[j].Payload, headerMsg.Base.MsgType) - if err != nil { - log.Printf("Failed to unmarshal tsMsg, error = %v", err) - } - - if headerMsg.Base.MsgType == commonpb.MsgType_kTimeTick { - if tsMsg.BeginTs() >= mp.Timestamp { - ms.unsolvedMutex.Unlock() - return nil - } - continue - } - if tsMsg.BeginTs() > mp.Timestamp { - tsMsg.SetPosition(&msgstream.MsgPosition{ - ChannelName: filepath.Base(consumer.ChannelName), - MsgID: strconv.Itoa(int(rmqMsg[j].MsgID)), - }) - ms.unsolvedBuf[consumer] = append(ms.unsolvedBuf[consumer], tsMsg) - } + if tsMsg.BeginTs() > mp.Timestamp { + tsMsg.SetPosition(&msgstream.MsgPosition{ + ChannelName: filepath.Base(consumer.ChannelName), + MsgID: strconv.Itoa(int(rmqMsg[0].MsgID)), + }) + ms.unsolvedBuf[consumer] = append(ms.unsolvedBuf[consumer], tsMsg) } } } + return nil + + //for { + // select { + // case <-ms.ctx.Done(): + // return nil + // case num, ok := <-consumer.MsgNum: + // if !ok { + // return errors.New("consumer closed") + // } + // rmqMsg, err := rocksmq.Rmq.Consume(consumer.GroupName, consumer.ChannelName, num) + // if err != nil { + // log.Printf("Failed to consume message in rocksmq, error = %v", err) + // continue + // } + // + // for j := 0; j < len(rmqMsg); j++ { + // headerMsg := commonpb.MsgHeader{} + // err := proto.Unmarshal(rmqMsg[j].Payload, &headerMsg) + // if err != nil { + // log.Printf("Failed to unmarshal message header, error = %v", err) + // } + // tsMsg, err := ms.unmarshal.Unmarshal(rmqMsg[j].Payload, headerMsg.Base.MsgType) + // if err != nil { + // log.Printf("Failed to unmarshal tsMsg, error = %v", err) + // } + // + // if headerMsg.Base.MsgType == commonpb.MsgType_kTimeTick { + // if tsMsg.BeginTs() >= mp.Timestamp { + // ms.unsolvedMutex.Unlock() + // return nil + // } + // continue + // } + // if tsMsg.BeginTs() > mp.Timestamp { + // tsMsg.SetPosition(&msgstream.MsgPosition{ + // ChannelName: filepath.Base(consumer.ChannelName), + // MsgID: strconv.Itoa(int(rmqMsg[j].MsgID)), + // }) + // ms.unsolvedBuf[consumer] = append(ms.unsolvedBuf[consumer], tsMsg) + // } + // } + // } + //} } func checkTimeTickMsg(msg map[rocksmq.Consumer]Timestamp, diff --git a/internal/types/types.go b/internal/types/types.go index f7f7a3faf2..dfbbdd82c0 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -25,7 +25,7 @@ type Component interface { GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) } -type DataNodeService interface { +type DataNode interface { Component WatchDmChannels(ctx context.Context, in *datapb.WatchDmChannelRequest) (*commonpb.Status, error) @@ -51,7 +51,7 @@ type DataService interface { GetSegmentInfo(ctx context.Context, req *datapb.SegmentInfoRequest) (*datapb.SegmentInfoResponse, error) } -type IndexNodeService interface { +type IndexNode interface { Component TimeTickProvider @@ -110,41 +110,44 @@ type MasterService interface { ShowSegments(ctx context.Context, in *milvuspb.ShowSegmentRequest) (*milvuspb.ShowSegmentResponse, error) } -type ProxyNodeService interface { +type ProxyNode interface { Component InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) - CreateCollection(ctx context.Context, request *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) - DropCollection(ctx context.Context, request *milvuspb.DropCollectionRequest) (*commonpb.Status, error) - HasCollection(ctx context.Context, request *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) - LoadCollection(ctx context.Context, request *milvuspb.LoadCollectionRequest) (*commonpb.Status, error) - ReleaseCollection(ctx context.Context, request *milvuspb.ReleaseCollectionRequest) (*commonpb.Status, error) - DescribeCollection(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) - GetCollectionStatistics(ctx context.Context, request *milvuspb.CollectionStatsRequest) (*milvuspb.CollectionStatsResponse, error) - ShowCollections(ctx context.Context, request *milvuspb.ShowCollectionRequest) (*milvuspb.ShowCollectionResponse, error) + //TODO: move to milvus service + /* + CreateCollection(ctx context.Context, request *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) + DropCollection(ctx context.Context, request *milvuspb.DropCollectionRequest) (*commonpb.Status, error) + HasCollection(ctx context.Context, request *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) + LoadCollection(ctx context.Context, request *milvuspb.LoadCollectionRequest) (*commonpb.Status, error) + ReleaseCollection(ctx context.Context, request *milvuspb.ReleaseCollectionRequest) (*commonpb.Status, error) + DescribeCollection(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) + GetCollectionStatistics(ctx context.Context, request *milvuspb.CollectionStatsRequest) (*milvuspb.CollectionStatsResponse, error) + ShowCollections(ctx context.Context, request *milvuspb.ShowCollectionRequest) (*milvuspb.ShowCollectionResponse, error) - CreatePartition(ctx context.Context, request *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) - DropPartition(ctx context.Context, request *milvuspb.DropPartitionRequest) (*commonpb.Status, error) - HasPartition(ctx context.Context, request *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) - LoadPartitions(ctx context.Context, request *milvuspb.LoadPartitonRequest) (*commonpb.Status, error) - ReleasePartitions(ctx context.Context, request *milvuspb.ReleasePartitionRequest) (*commonpb.Status, error) - GetPartitionStatistics(ctx context.Context, request *milvuspb.PartitionStatsRequest) (*milvuspb.PartitionStatsResponse, error) - ShowPartitions(ctx context.Context, request *milvuspb.ShowPartitionRequest) (*milvuspb.ShowPartitionResponse, error) + CreatePartition(ctx context.Context, request *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) + DropPartition(ctx context.Context, request *milvuspb.DropPartitionRequest) (*commonpb.Status, error) + HasPartition(ctx context.Context, request *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) + LoadPartitions(ctx context.Context, request *milvuspb.LoadPartitonRequest) (*commonpb.Status, error) + ReleasePartitions(ctx context.Context, request *milvuspb.ReleasePartitionRequest) (*commonpb.Status, error) + GetPartitionStatistics(ctx context.Context, request *milvuspb.PartitionStatsRequest) (*milvuspb.PartitionStatsResponse, error) + ShowPartitions(ctx context.Context, request *milvuspb.ShowPartitionRequest) (*milvuspb.ShowPartitionResponse, error) - CreateIndex(ctx context.Context, request *milvuspb.CreateIndexRequest) (*commonpb.Status, error) - DescribeIndex(ctx context.Context, request *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) - GetIndexState(ctx context.Context, request *milvuspb.IndexStateRequest) (*milvuspb.IndexStateResponse, error) - DropIndex(ctx context.Context, request *milvuspb.DropIndexRequest) (*commonpb.Status, error) + CreateIndex(ctx context.Context, request *milvuspb.CreateIndexRequest) (*commonpb.Status, error) + DescribeIndex(ctx context.Context, request *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) + GetIndexState(ctx context.Context, request *milvuspb.IndexStateRequest) (*milvuspb.IndexStateResponse, error) + DropIndex(ctx context.Context, request *milvuspb.DropIndexRequest) (*commonpb.Status, error) - Insert(ctx context.Context, request *milvuspb.InsertRequest) (*milvuspb.InsertResponse, error) - Search(ctx context.Context, request *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) - Flush(ctx context.Context, request *milvuspb.FlushRequest) (*commonpb.Status, error) + Insert(ctx context.Context, request *milvuspb.InsertRequest) (*milvuspb.InsertResponse, error) + Search(ctx context.Context, request *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) + Flush(ctx context.Context, request *milvuspb.FlushRequest) (*commonpb.Status, error) - GetDdChannel(ctx context.Context, request *commonpb.Empty) (*milvuspb.StringResponse, error) + GetDdChannel(ctx context.Context, request *commonpb.Empty) (*milvuspb.StringResponse, error) - GetQuerySegmentInfo(ctx context.Context, req *milvuspb.QuerySegmentInfoRequest) (*milvuspb.QuerySegmentInfoResponse, error) - GetPersistentSegmentInfo(ctx context.Context, req *milvuspb.PersistentSegmentInfoRequest) (*milvuspb.PersistentSegmentInfoResponse, error) + GetQuerySegmentInfo(ctx context.Context, req *milvuspb.QuerySegmentInfoRequest) (*milvuspb.QuerySegmentInfoResponse, error) + GetPersistentSegmentInfo(ctx context.Context, req *milvuspb.PersistentSegmentInfoRequest) (*milvuspb.PersistentSegmentInfoResponse, error) + */ } type ProxyService interface { @@ -155,7 +158,7 @@ type ProxyService interface { InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) } -type QueryNodeService interface { +type QueryNode interface { Component TimeTickProvider