diff --git a/internal/msgstream/pulsarms/pulsar_msgstream.go b/internal/msgstream/pulsarms/pulsar_msgstream.go index d201934800..d4fca08978 100644 --- a/internal/msgstream/pulsarms/pulsar_msgstream.go +++ b/internal/msgstream/pulsarms/pulsar_msgstream.go @@ -721,11 +721,16 @@ func (ms *PulsarTtMsgStream) Seek(mp *internalpb2.MsgPosition) error { var messageID MessageID for index, channel := range ms.consumerChannels { if filepath.Base(channel) == filepath.Base(mp.ChannelName) { + consumer = ms.consumers[index] + if len(mp.MsgID) == 0 { + // TODO: each collection should have separate channels; otherwise redundant messages will be consumed + messageID = pulsar.EarliestMessageID() + break + } seekMsgID, err := typeutil.StringToPulsarMsgID(mp.MsgID) if err != nil { return err } - consumer = ms.consumers[index] messageID = seekMsgID break } @@ -736,6 +741,9 @@ func (ms *PulsarTtMsgStream) Seek(mp *internalpb2.MsgPosition) error { if err != nil { return err } + if messageID == nil { + return nil + } ms.unsolvedMutex.Lock() ms.unsolvedBuf[consumer] = make([]TsMsg, 0) diff --git a/internal/querynode/data_sync_service.go b/internal/querynode/data_sync_service.go index 909e4a116c..a66632cdf5 100644 --- a/internal/querynode/data_sync_service.go +++ b/internal/querynode/data_sync_service.go @@ -5,6 +5,7 @@ import ( "log" "github.com/zilliztech/milvus-distributed/internal/msgstream" + "github.com/zilliztech/milvus-distributed/internal/proto/internalpb2" "github.com/zilliztech/milvus-distributed/internal/util/flowgraph" ) @@ -95,3 +96,11 @@ func (dsService *dataSyncService) initNodes() { log.Fatal("set edges failed in node:", serviceTimeNode.Name()) } } + +func (dsService *dataSyncService) seekSegment(position *internalpb2.MsgPosition) error { + err := dsService.dmStream.Seek(position) + if err != nil { + return err + } + return nil +} diff --git a/internal/querynode/query_node.go b/internal/querynode/query_node.go index 34f22c6e4e..6e57945bb7 100644 --- a/internal/querynode/query_node.go +++ 
b/internal/querynode/query_node.go @@ -403,55 +403,64 @@ func (node *QueryNode) LoadSegments(in *queryPb.LoadSegmentRequest) (*commonpb.S fieldIDs := in.FieldIDs schema := in.Schema + fmt.Println("query node load segment ,info = ", in) + + status := &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_SUCCESS, + } hasCollection := node.replica.hasCollection(collectionID) hasPartition := node.replica.hasPartition(partitionID) if !hasCollection { err := node.replica.addCollection(collectionID, schema) if err != nil { - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, - Reason: err.Error(), - } + status.ErrorCode = commonpb.ErrorCode_UNEXPECTED_ERROR + status.Reason = err.Error() return status, err } } if !hasPartition { err := node.replica.addPartition(collectionID, partitionID) if err != nil { - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, - Reason: err.Error(), - } + status.ErrorCode = commonpb.ErrorCode_UNEXPECTED_ERROR + status.Reason = err.Error() return status, err } } err := node.replica.enablePartition(partitionID) if err != nil { - status := &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, - Reason: err.Error(), - } + status.ErrorCode = commonpb.ErrorCode_UNEXPECTED_ERROR + status.Reason = err.Error() + return status, err + } + + if len(segmentIDs) == 0 { + return status, nil + } + + if len(in.SegmentIDs) != len(in.SegmentStates) { + err := errors.New("len(segmentIDs) should equal to len(segmentStates)") + status.ErrorCode = commonpb.ErrorCode_UNEXPECTED_ERROR + status.Reason = err.Error() return status, err } // segments are ordered before LoadSegments calling + var position *internalpb2.MsgPosition = nil for i, state := range in.SegmentStates { - if state.State == commonpb.SegmentState_SegmentGrowing { - position := state.StartPosition - err := node.loadService.segLoader.seekSegment(position) - if err != nil { - status := &commonpb.Status{ - ErrorCode: 
commonpb.ErrorCode_UNEXPECTED_ERROR, - Reason: err.Error(), + thisPosition := state.StartPosition + if state.State <= commonpb.SegmentState_SegmentGrowing { + if position == nil { + position = &internalpb2.MsgPosition{ + ChannelName: thisPosition.ChannelName, } - return status, err } segmentIDs = segmentIDs[:i] break } + position = state.StartPosition } - err = node.loadService.loadSegment(collectionID, partitionID, segmentIDs, fieldIDs) + err = node.dataSyncService.seekSegment(position) if err != nil { status := &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR, @@ -459,9 +468,14 @@ func (node *QueryNode) LoadSegments(in *queryPb.LoadSegmentRequest) (*commonpb.S } return status, err } - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_SUCCESS, - }, nil + + err = node.loadService.loadSegment(collectionID, partitionID, segmentIDs, fieldIDs) + if err != nil { + status.ErrorCode = commonpb.ErrorCode_UNEXPECTED_ERROR + status.Reason = err.Error() + return status, err + } + return status, nil } func (node *QueryNode) ReleaseCollection(in *queryPb.ReleaseCollectionRequest) (*commonpb.Status, error) { diff --git a/internal/querynode/segment_loader.go b/internal/querynode/segment_loader.go index 2820b96018..c4e3a18bc3 100644 --- a/internal/querynode/segment_loader.go +++ b/internal/querynode/segment_loader.go @@ -28,17 +28,6 @@ type segmentLoader struct { indexLoader *indexLoader } -func (loader *segmentLoader) seekSegment(position *internalpb2.MsgPosition) error { - // TODO: open seek - //for _, position := range positions { - // err := s.dmStream.Seek(position) - // if err != nil { - // return err - // } - //} - return nil -} - func (loader *segmentLoader) getInsertBinlogPaths(segmentID UniqueID) ([]*internalpb2.StringList, []int64, error) { ctx := context.TODO() if loader.dataClient == nil { diff --git a/tests/python/test_load_collection.py b/tests/python/test_load_collection.py index 97f07ba054..e0f2e8b52b 100644 --- 
a/tests/python/test_load_collection.py +++ b/tests/python/test_load_collection.py @@ -19,4 +19,4 @@ class TestLoadCollection: ids = connect.insert(collection, default_entities) ids = connect.insert(collection, default_entity) connect.flush([collection]) - connect.load_collection(collection) + connect.load_collection(collection) \ No newline at end of file