From deb9963d0e50948d0efa04a1bc266942bb393883 Mon Sep 17 00:00:00 2001 From: SimFG Date: Thu, 3 Nov 2022 21:41:35 +0800 Subject: [PATCH] Delete unused messages for the mq (#20295) Signed-off-by: SimFG Signed-off-by: SimFG --- internal/mq/msgstream/mq_msgstream_test.go | 97 ++--- internal/mq/msgstream/msg.go | 405 --------------------- internal/mq/msgstream/msg_test.go | 299 --------------- internal/mq/msgstream/unmarshal.go | 10 - internal/querynode/mock_test.go | 20 +- 5 files changed, 25 insertions(+), 806 deletions(-) diff --git a/internal/mq/msgstream/mq_msgstream_test.go b/internal/mq/msgstream/mq_msgstream_test.go index 3db61ee5eb..af4dae01d5 100644 --- a/internal/mq/msgstream/mq_msgstream_test.go +++ b/internal/mq/msgstream/mq_msgstream_test.go @@ -439,51 +439,6 @@ func TestStream_PulsarMsgStream_Delete(t *testing.T) { outputStream.Close() } -func TestStream_PulsarMsgStream_Search(t *testing.T) { - pulsarAddress := getPulsarAddress() - c := funcutil.RandomString(8) - producerChannels := []string{c} - consumerChannels := []string{c} - consumerSubName := funcutil.RandomString(8) - - msgPack := MsgPack{} - msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Search, 1)) - msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Search, 3)) - - ctx := context.Background() - inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels) - outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName) - - err := inputStream.Produce(&msgPack) - require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) - - receiveMsg(ctx, outputStream, len(msgPack.Msgs)) - inputStream.Close() - outputStream.Close() -} - -func TestStream_PulsarMsgStream_SearchResult(t *testing.T) { - pulsarAddress := getPulsarAddress() - c := funcutil.RandomString(8) - producerChannels := []string{c} - consumerChannels := []string{c} - consumerSubName := funcutil.RandomString(8) - msgPack := MsgPack{} - msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_SearchResult, 1)) - msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_SearchResult, 3)) - - ctx := context.Background() - inputStream := getPulsarInputStream(ctx, pulsarAddress, producerChannels) - outputStream := getPulsarOutputStream(ctx, pulsarAddress, consumerChannels, consumerSubName) - - err := inputStream.Produce(&msgPack) - require.NoErrorf(t, err, fmt.Sprintf("produce error = %v", err)) - - receiveMsg(ctx, outputStream, len(msgPack.Msgs)) - inputStream.Close() - outputStream.Close() -} - func TestStream_PulsarMsgStream_TimeTick(t *testing.T) { pulsarAddress := getPulsarAddress() c := funcutil.RandomString(8) @@ -672,8 +627,8 @@ func TestStream_PulsarMsgStream_DefaultRepackFunc(t *testing.T) { msgPack := MsgPack{} msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_TimeTick, 1)) - msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Search, 2)) - msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_SearchResult, 3)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Insert, 2)) + msgPack.Msgs = append(msgPack.Msgs, getTsMsg(commonpb.MsgType_Delete, 3)) factory := ProtoUDFactory{} @@ -1572,8 +1527,8 @@ func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) { // would not dedup for non-dml messages msgPack2 := MsgPack{} - msgPack2.Msgs = append(msgPack2.Msgs, getTsMsg(commonpb.MsgType_Search, 2)) - msgPack2.Msgs = append(msgPack2.Msgs, getTsMsg(commonpb.MsgType_Search, 2)) + msgPack2.Msgs = append(msgPack2.Msgs, getTsMsg(commonpb.MsgType_CreateCollection, 2)) + msgPack2.Msgs = append(msgPack2.Msgs, getTsMsg(commonpb.MsgType_CreateCollection, 2)) msgPack3 := MsgPack{} msgPack3.Msgs = append(msgPack3.Msgs, getTimeTickMsg(15)) @@ -1608,8 +1563,8 @@ func TestStream_RmqTtMsgStream_DuplicatedIDs(t *testing.T) { seekMsg := consumer(ctx, outputStream) assert.Equal(t, len(seekMsg.Msgs), 1+2) assert.EqualValues(t, seekMsg.Msgs[0].BeginTs(), 1) - assert.Equal(t, commonpb.MsgType_Search, seekMsg.Msgs[1].Type()) - assert.Equal(t, commonpb.MsgType_Search, seekMsg.Msgs[2].Type()) + assert.Equal(t, commonpb.MsgType_CreateCollection, seekMsg.Msgs[1].Type()) + assert.Equal(t, commonpb.MsgType_CreateCollection, seekMsg.Msgs[2].Type()) Close(rocksdbName, inputStream, outputStream, etcdKV) } @@ -1958,37 +1913,29 @@ func getTsMsg(msgType MsgType, reqID UniqueID) TsMsg { DeleteRequest: deleteRequest, } return deleteMsg - case commonpb.MsgType_Search: - searchRequest := internalpb.SearchRequest{ + case commonpb.MsgType_CreateCollection: + createCollectionRequest := internalpb.CreateCollectionRequest{ Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Search, + MsgType: commonpb.MsgType_CreateCollection, MsgID: reqID, Timestamp: 11, SourceID: reqID, }, - ReqID: 0, + DbName: "test_db", + CollectionName: "test_collection", + PartitionName: "test_partition", + DbID: 4, + CollectionID: 5, + PartitionID: 6, + Schema: []byte{}, + VirtualChannelNames: []string{}, + PhysicalChannelNames: []string{}, } - searchMsg := &SearchMsg{ - BaseMsg: baseMsg, - SearchRequest: searchRequest, + createCollectionMsg := &CreateCollectionMsg{ + BaseMsg: baseMsg, + CreateCollectionRequest: createCollectionRequest, } - return searchMsg - case commonpb.MsgType_SearchResult: - searchResult := internalpb.SearchResults{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_SearchResult, - MsgID: reqID, - Timestamp: 1, - SourceID: reqID, - }, - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, - ReqID: 0, - } - searchResultMsg := &SearchResultMsg{ - BaseMsg: baseMsg, - SearchResults: searchResult, - } - return searchResultMsg + return createCollectionMsg case commonpb.MsgType_TimeTick: timeTickResult := internalpb.TimeTickMsg{ Base: &commonpb.MsgBase{ diff --git a/internal/mq/msgstream/msg.go b/internal/mq/msgstream/msg.go index de6fcecefe..b5aa811dfd 100644 --- a/internal/mq/msgstream/msg.go +++ b/internal/mq/msgstream/msg.go @@ -20,7 +20,6 @@ import ( "context" "errors" "fmt" - "time" "github.com/milvus-io/milvus-proto/go-api/schemapb" "github.com/milvus-io/milvus/internal/util/commonpbutil" @@ -32,8 +31,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/proto/querypb" - "github.com/milvus-io/milvus/internal/util/timerecord" ) // MsgType is an alias of commonpb.MsgType @@ -383,292 +380,6 @@ func (dt *DeleteMsg) CheckAligned() error { return nil } -/////////////////////////////////////////Search////////////////////////////////////////// - -// SearchMsg is a message pack that contains search request -type SearchMsg struct { - BaseMsg - internalpb.SearchRequest - tr *timerecord.TimeRecorder -} - -// interface implementation validation -var _ TsMsg = &SearchMsg{} - -// ID returns the ID of this message pack -func (st *SearchMsg) ID() UniqueID { - return st.Base.MsgID -} - -// Type returns the type of this message pack -func (st *SearchMsg) Type() MsgType { - return st.Base.MsgType -} - -// SourceID indicates which component generated this message -func (st *SearchMsg) SourceID() int64 { - return st.Base.SourceID -} - -// GuaranteeTs returns the guarantee timestamp that querynode can perform this search request. This timestamp -// filled in client(e.g. pymilvus). The timestamp will be 0 if client never execute any insert, otherwise equals -// the timestamp from last insert response. -func (st *SearchMsg) GuaranteeTs() Timestamp { - return st.GetGuaranteeTimestamp() -} - -// TravelTs returns the timestamp of a time travel search request -func (st *SearchMsg) TravelTs() Timestamp { - return st.GetTravelTimestamp() -} - -// TimeoutTs returns the timestamp of timeout -func (st *SearchMsg) TimeoutTs() Timestamp { - return st.GetTimeoutTimestamp() -} - -// SetTimeRecorder sets the timeRecorder for RetrieveMsg -func (st *SearchMsg) SetTimeRecorder() { - st.tr = timerecord.NewTimeRecorder("searchMsg") -} - -// ElapseSpan returns the duration from the beginning -func (st *SearchMsg) ElapseSpan() time.Duration { - return st.tr.ElapseSpan() -} - -// RecordSpan returns the duration from last record -func (st *SearchMsg) RecordSpan() time.Duration { - return st.tr.RecordSpan() -} - -// Marshal is used to serializing a message pack to byte array -func (st *SearchMsg) Marshal(input TsMsg) (MarshalType, error) { - searchTask := input.(*SearchMsg) - searchRequest := &searchTask.SearchRequest - mb, err := proto.Marshal(searchRequest) - if err != nil { - return nil, err - } - return mb, nil -} - -// Unmarshal is used to deserializing a message pack from byte array -func (st *SearchMsg) Unmarshal(input MarshalType) (TsMsg, error) { - searchRequest := internalpb.SearchRequest{} - in, err := convertToByteArray(input) - if err != nil { - return nil, err - } - err = proto.Unmarshal(in, &searchRequest) - if err != nil { - return nil, err - } - searchMsg := &SearchMsg{SearchRequest: searchRequest} - searchMsg.BeginTimestamp = searchMsg.Base.Timestamp - searchMsg.EndTimestamp = searchMsg.Base.Timestamp - - return searchMsg, nil -} - -/////////////////////////////////////////SearchResult////////////////////////////////////////// - -// SearchResultMsg is a message pack that contains the result of search request -type SearchResultMsg struct { - BaseMsg - internalpb.SearchResults -} - -// interface implementation validation -var _ TsMsg = &SearchResultMsg{} - -// ID returns the ID of this message pack -func (srt *SearchResultMsg) ID() UniqueID { - return srt.Base.MsgID -} - -// Type returns the type of this message pack -func (srt *SearchResultMsg) Type() MsgType { - return srt.Base.MsgType -} - -// SourceID indicates which component generated this message -func (srt *SearchResultMsg) SourceID() int64 { - return srt.Base.SourceID -} - -// Marshal is used to serializing a message pack to byte array -func (srt *SearchResultMsg) Marshal(input TsMsg) (MarshalType, error) { - searchResultTask := input.(*SearchResultMsg) - searchResultRequest := &searchResultTask.SearchResults - mb, err := proto.Marshal(searchResultRequest) - if err != nil { - return nil, err - } - return mb, nil -} - -// Unmarshal is used to deserializing a message pack from byte array -func (srt *SearchResultMsg) Unmarshal(input MarshalType) (TsMsg, error) { - searchResultRequest := internalpb.SearchResults{} - in, err := convertToByteArray(input) - if err != nil { - return nil, err - } - err = proto.Unmarshal(in, &searchResultRequest) - if err != nil { - return nil, err - } - searchResultMsg := &SearchResultMsg{SearchResults: searchResultRequest} - searchResultMsg.BeginTimestamp = searchResultMsg.Base.Timestamp - searchResultMsg.EndTimestamp = searchResultMsg.Base.Timestamp - - return searchResultMsg, nil -} - -////////////////////////////////////////Retrieve///////////////////////////////////////// - -// RetrieveMsg is a message pack that contains retrieve request -type RetrieveMsg struct { - BaseMsg - internalpb.RetrieveRequest - tr *timerecord.TimeRecorder -} - -// interface implementation validation -var _ TsMsg = &RetrieveMsg{} - -// ID returns the ID of this message pack -func (rm *RetrieveMsg) ID() UniqueID { - return rm.Base.MsgID -} - -// Type returns the type of this message pack -func (rm *RetrieveMsg) Type() MsgType { - return rm.Base.MsgType -} - -// SourceID indicates which component generated this message -func (rm *RetrieveMsg) SourceID() int64 { - return rm.Base.SourceID -} - -// GuaranteeTs returns the guarantee timestamp that querynode can perform this query request. This timestamp -// filled in client(e.g. pymilvus). The timestamp will be 0 if client never execute any insert, otherwise equals -// the timestamp from last insert response. -func (rm *RetrieveMsg) GuaranteeTs() Timestamp { - return rm.GetGuaranteeTimestamp() -} - -// TravelTs returns the timestamp of a time travel query request -func (rm *RetrieveMsg) TravelTs() Timestamp { - return rm.GetTravelTimestamp() -} - -// TimeoutTs returns the timestamp of timeout -func (rm *RetrieveMsg) TimeoutTs() Timestamp { - return rm.GetTimeoutTimestamp() -} - -// SetTimeRecorder sets the timeRecorder for RetrieveMsg -func (rm *RetrieveMsg) SetTimeRecorder() { - rm.tr = timerecord.NewTimeRecorder("retrieveMsg") -} - -// ElapseSpan returns the duration from the beginning -func (rm *RetrieveMsg) ElapseSpan() time.Duration { - return rm.tr.ElapseSpan() -} - -// RecordSpan returns the duration from last record -func (rm *RetrieveMsg) RecordSpan() time.Duration { - return rm.tr.RecordSpan() -} - -// Marshal is used to serializing a message pack to byte array -func (rm *RetrieveMsg) Marshal(input TsMsg) (MarshalType, error) { - retrieveTask := input.(*RetrieveMsg) - retrieveRequest := &retrieveTask.RetrieveRequest - mb, err := proto.Marshal(retrieveRequest) - if err != nil { - return nil, err - } - return mb, nil -} - -// Unmarshal is used to deserializing a message pack from byte array -func (rm *RetrieveMsg) Unmarshal(input MarshalType) (TsMsg, error) { - retrieveRequest := internalpb.RetrieveRequest{} - in, err := convertToByteArray(input) - if err != nil { - return nil, err - } - err = proto.Unmarshal(in, &retrieveRequest) - if err != nil { - return nil, err - } - retrieveMsg := &RetrieveMsg{RetrieveRequest: retrieveRequest} - retrieveMsg.BeginTimestamp = retrieveMsg.Base.Timestamp - retrieveMsg.EndTimestamp = retrieveMsg.Base.Timestamp - - return retrieveMsg, nil -} - -//////////////////////////////////////RetrieveResult/////////////////////////////////////// - -// RetrieveResultMsg is a message pack that contains the result of query request -type RetrieveResultMsg struct { - BaseMsg - internalpb.RetrieveResults -} - -// interface implementation validation -var _ TsMsg = &RetrieveResultMsg{} - -// ID returns the ID of this message pack -func (rrm *RetrieveResultMsg) ID() UniqueID { - return rrm.Base.MsgID -} - -// Type returns the type of this message pack -func (rrm *RetrieveResultMsg) Type() MsgType { - return rrm.Base.MsgType -} - -// SourceID indicates which component generated this message -func (rrm *RetrieveResultMsg) SourceID() int64 { - return rrm.Base.SourceID -} - -// Marshal is used to serializing a message pack to byte array -func (rrm *RetrieveResultMsg) Marshal(input TsMsg) (MarshalType, error) { - retrieveResultTask := input.(*RetrieveResultMsg) - retrieveResultRequest := &retrieveResultTask.RetrieveResults - mb, err := proto.Marshal(retrieveResultRequest) - if err != nil { - return nil, err - } - return mb, nil -} - -// Unmarshal is used to deserializing a message pack from byte array -func (rrm *RetrieveResultMsg) Unmarshal(input MarshalType) (TsMsg, error) { - retrieveResultRequest := internalpb.RetrieveResults{} - in, err := convertToByteArray(input) - if err != nil { - return nil, err - } - err = proto.Unmarshal(in, &retrieveResultRequest) - if err != nil { - return nil, err - } - retrieveResultMsg := &RetrieveResultMsg{RetrieveResults: retrieveResultRequest} - retrieveResultMsg.BeginTimestamp = retrieveResultMsg.Base.Timestamp - retrieveResultMsg.EndTimestamp = retrieveResultMsg.Base.Timestamp - - return retrieveResultMsg, nil -} - /////////////////////////////////////////TimeTick////////////////////////////////////////// // TimeTickMsg is a message pack that contains time tick only @@ -944,122 +655,6 @@ func (dp *DropPartitionMsg) Unmarshal(input MarshalType) (TsMsg, error) { return dropPartitionMsg, nil } -/////////////////////////////////////////LoadIndex////////////////////////////////////////// -// FIXME(wxyu): comment it until really needed -/* -type LoadIndexMsg struct { - BaseMsg - internalpb.LoadIndex -} - -// TraceCtx returns the context of opentracing -func (lim *LoadIndexMsg) TraceCtx() context.Context { - return lim.BaseMsg.Ctx -} - -// SetTraceCtx is used to set context for opentracing -func (lim *LoadIndexMsg) SetTraceCtx(ctx context.Context) { - lim.BaseMsg.Ctx = ctx -} - -// ID returns the ID of this message pack -func (lim *LoadIndexMsg) ID() UniqueID { - return lim.Base.MsgID -} - -// Type returns the type of this message pack -func (lim *LoadIndexMsg) Type() MsgType { - return lim.Base.MsgType -} - -// SourceID indicated which component generated this message -func (lim *LoadIndexMsg) SourceID() int64 { - return lim.Base.SourceID -} - -// Marshal is used to serializing a message pack to byte array -func (lim *LoadIndexMsg) Marshal(input TsMsg) (MarshalType, error) { - loadIndexMsg := input.(*LoadIndexMsg) - loadIndexRequest := &loadIndexMsg.LoadIndex - mb, err := proto.Marshal(loadIndexRequest) - if err != nil { - return nil, err - } - return mb, nil -} - -// Unmarshal is used to deserializing a message pack from byte array -func (lim *LoadIndexMsg) Unmarshal(input MarshalType) (TsMsg, error) { - loadIndexRequest := internalpb.LoadIndex{} - in, err := convertToByteArray(input) - if err != nil { - return nil, err - } - err = proto.Unmarshal(in, &loadIndexRequest) - if err != nil { - return nil, err - } - loadIndexMsg := &LoadIndexMsg{LoadIndex: loadIndexRequest} - - return loadIndexMsg, nil -} -*/ - -/////////////////////////////////////////SealedSegmentsChangeInfoMsg////////////////////////////////////////// - -// SealedSegmentsChangeInfoMsg is a message pack that contains sealed segments change info -type SealedSegmentsChangeInfoMsg struct { - BaseMsg - querypb.SealedSegmentsChangeInfo -} - -// interface implementation validation -var _ TsMsg = &SealedSegmentsChangeInfoMsg{} - -// ID returns the ID of this message pack -func (s *SealedSegmentsChangeInfoMsg) ID() UniqueID { - return s.Base.MsgID -} - -// Type returns the type of this message pack -func (s *SealedSegmentsChangeInfoMsg) Type() MsgType { - return s.Base.MsgType -} - -// SourceID indicates which component generated this message -func (s *SealedSegmentsChangeInfoMsg) SourceID() int64 { - return s.Base.SourceID -} - -// Marshal is used to serializing a message pack to byte array -func (s *SealedSegmentsChangeInfoMsg) Marshal(input TsMsg) (MarshalType, error) { - changeInfoMsg := input.(*SealedSegmentsChangeInfoMsg) - changeInfo := &changeInfoMsg.SealedSegmentsChangeInfo - mb, err := proto.Marshal(changeInfo) - if err != nil { - return nil, err - } - return mb, nil -} - -// Unmarshal is used to deserializing a message pack from byte array -func (s *SealedSegmentsChangeInfoMsg) Unmarshal(input MarshalType) (TsMsg, error) { - changeInfo := querypb.SealedSegmentsChangeInfo{} - in, err := convertToByteArray(input) - if err != nil { - return nil, err - } - err = proto.Unmarshal(in, &changeInfo) - if err != nil { - return nil, err - } - changeInfoMsg := &SealedSegmentsChangeInfoMsg{SealedSegmentsChangeInfo: changeInfo} - changeInfoMsg.BeginTimestamp = changeInfo.Base.Timestamp - changeInfoMsg.EndTimestamp = changeInfo.Base.Timestamp - - return changeInfoMsg, nil -} - /////////////////////////////////////////DataNodeTtMsg////////////////////////////////////////// // DataNodeTtMsg is a message pack that contains datanode time tick diff --git a/internal/mq/msgstream/msg_test.go b/internal/mq/msgstream/msg_test.go index 026a501055..a00bbf4750 100644 --- a/internal/mq/msgstream/msg_test.go +++ b/internal/mq/msgstream/msg_test.go @@ -26,7 +26,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/schemapb" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/proto/querypb" ) func TestBaseMsg(t *testing.T) { @@ -320,240 +319,6 @@ func TestDeleteMsg_Unmarshal_IllegalParameter(t *testing.T) { assert.Nil(t, tsMsg) } -func TestSearchMsg(t *testing.T) { - searchMsg := &SearchMsg{ - BaseMsg: generateBaseMsg(), - SearchRequest: internalpb.SearchRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Search, - MsgID: 1, - Timestamp: 2, - SourceID: 3, - }, - DbID: 4, - CollectionID: 5, - PartitionIDs: []int64{}, - Dsl: "dsl", - PlaceholderGroup: []byte{}, - DslType: commonpb.DslType_BoolExprV1, - SerializedExprPlan: []byte{}, - OutputFieldsId: []int64{}, - TravelTimestamp: 6, - GuaranteeTimestamp: 7, - TimeoutTimestamp: 8, - }, - } - - assert.NotNil(t, searchMsg.TraceCtx()) - - ctx := context.Background() - searchMsg.SetTraceCtx(ctx) - assert.Equal(t, ctx, searchMsg.TraceCtx()) - - assert.Equal(t, int64(1), searchMsg.ID()) - assert.Equal(t, commonpb.MsgType_Search, searchMsg.Type()) - assert.Equal(t, int64(3), searchMsg.SourceID()) - assert.Equal(t, uint64(7), searchMsg.GuaranteeTs()) - assert.Equal(t, uint64(6), searchMsg.TravelTs()) - assert.Equal(t, uint64(8), searchMsg.TimeoutTs()) - - bytes, err := searchMsg.Marshal(searchMsg) - assert.Nil(t, err) - - tsMsg, err := searchMsg.Unmarshal(bytes) - assert.Nil(t, err) - - searchMsg2, ok := tsMsg.(*SearchMsg) - assert.True(t, ok) - assert.Equal(t, int64(1), searchMsg2.ID()) - assert.Equal(t, commonpb.MsgType_Search, searchMsg2.Type()) - assert.Equal(t, int64(3), searchMsg2.SourceID()) - assert.Equal(t, uint64(7), searchMsg2.GuaranteeTs()) - assert.Equal(t, uint64(6), searchMsg2.TravelTs()) -} - -func TestSearchMsg_Unmarshal_IllegalParameter(t *testing.T) { - searchMsg := &SearchMsg{} - tsMsg, err := searchMsg.Unmarshal(10) - assert.NotNil(t, err) - assert.Nil(t, tsMsg) -} - -func TestSearchResultMsg(t *testing.T) { - searchResultMsg := &SearchResultMsg{ - BaseMsg: generateBaseMsg(), - SearchResults: internalpb.SearchResults{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_SearchResult, - MsgID: 1, - Timestamp: 2, - SourceID: 3, - }, - MetricType: "l2", - NumQueries: 5, - TopK: 6, - SealedSegmentIDsSearched: []int64{7}, - ChannelIDsSearched: []string{"test-searched"}, - GlobalSealedSegmentIDs: []int64{8}, - }, - } - - assert.NotNil(t, searchResultMsg.TraceCtx()) - - ctx := context.Background() - searchResultMsg.SetTraceCtx(ctx) - assert.Equal(t, ctx, searchResultMsg.TraceCtx()) - - assert.Equal(t, int64(1), searchResultMsg.ID()) - assert.Equal(t, commonpb.MsgType_SearchResult, searchResultMsg.Type()) - assert.Equal(t, int64(3), searchResultMsg.SourceID()) - - bytes, err := searchResultMsg.Marshal(searchResultMsg) - assert.Nil(t, err) - - tsMsg, err := searchResultMsg.Unmarshal(bytes) - assert.Nil(t, err) - - searchResultMsg2, ok := tsMsg.(*SearchResultMsg) - assert.True(t, ok) - assert.Equal(t, int64(1), searchResultMsg2.ID()) - assert.Equal(t, commonpb.MsgType_SearchResult, searchResultMsg2.Type()) - assert.Equal(t, int64(3), searchResultMsg2.SourceID()) -} - -func TestSearchResultMsg_Unmarshal_IllegalParameter(t *testing.T) { - searchResultMsg := &SearchResultMsg{} - tsMsg, err := searchResultMsg.Unmarshal(10) - assert.NotNil(t, err) - assert.Nil(t, tsMsg) -} - -func TestRetrieveMsg(t *testing.T) { - retrieveMsg := &RetrieveMsg{ - BaseMsg: generateBaseMsg(), - RetrieveRequest: internalpb.RetrieveRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Retrieve, - MsgID: 1, - Timestamp: 2, - SourceID: 3, - }, - DbID: 5, - CollectionID: 6, - PartitionIDs: []int64{7, 8}, - SerializedExprPlan: []byte{}, - OutputFieldsId: []int64{8, 9}, - TravelTimestamp: 10, - GuaranteeTimestamp: 11, - TimeoutTimestamp: 12, - }, - } - - assert.NotNil(t, retrieveMsg.TraceCtx()) - - ctx := context.Background() - retrieveMsg.SetTraceCtx(ctx) - assert.Equal(t, ctx, retrieveMsg.TraceCtx()) - - assert.Equal(t, int64(1), retrieveMsg.ID()) - assert.Equal(t, commonpb.MsgType_Retrieve, retrieveMsg.Type()) - assert.Equal(t, int64(3), retrieveMsg.SourceID()) - assert.Equal(t, uint64(11), retrieveMsg.GuaranteeTs()) - assert.Equal(t, uint64(10), retrieveMsg.TravelTs()) - assert.Equal(t, uint64(12), retrieveMsg.TimeoutTs()) - - bytes, err := retrieveMsg.Marshal(retrieveMsg) - assert.Nil(t, err) - - tsMsg, err := retrieveMsg.Unmarshal(bytes) - assert.Nil(t, err) - - retrieveMsg2, ok := tsMsg.(*RetrieveMsg) - assert.True(t, ok) - assert.Equal(t, int64(1), retrieveMsg2.ID()) - assert.Equal(t, commonpb.MsgType_Retrieve, retrieveMsg2.Type()) - assert.Equal(t, int64(3), retrieveMsg2.SourceID()) - assert.Equal(t, uint64(11), retrieveMsg2.GuaranteeTs()) - assert.Equal(t, uint64(10), retrieveMsg2.TravelTs()) -} - -func TestRetrieveMsg_Unmarshal_IllegalParameter(t *testing.T) { - retrieveMsg := &RetrieveMsg{} - tsMsg, err := retrieveMsg.Unmarshal(10) - assert.NotNil(t, err) - assert.Nil(t, tsMsg) -} - -func TestRetrieveResultMsg(t *testing.T) { - retrieveResultMsg := &RetrieveResultMsg{ - BaseMsg: generateBaseMsg(), - RetrieveResults: internalpb.RetrieveResults{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_RetrieveResult, - MsgID: 1, - Timestamp: 2, - SourceID: 3, - }, - Ids: &schemapb.IDs{ - IdField: &schemapb.IDs_IntId{ - IntId: &schemapb.LongArray{ - Data: []int64{}, - }, - }, - }, - FieldsData: []*schemapb.FieldData{ - { - Type: schemapb.DataType_FloatVector, - FieldName: "vector_field", - Field: &schemapb.FieldData_Vectors{ - Vectors: &schemapb.VectorField{ - Dim: 4, - Data: &schemapb.VectorField_FloatVector{ - FloatVector: &schemapb.FloatArray{ - Data: []float32{1.1, 2.2, 3.3, 4.4}, - }, - }, - }, - }, - FieldId: 5, - }, - }, - SealedSegmentIDsRetrieved: []int64{6, 7}, - ChannelIDsRetrieved: []string{"test-retrieved-channel"}, - GlobalSealedSegmentIDs: []int64{8, 9}, - }, - } - - assert.NotNil(t, retrieveResultMsg.TraceCtx()) - - ctx := context.Background() - retrieveResultMsg.SetTraceCtx(ctx) - assert.Equal(t, ctx, retrieveResultMsg.TraceCtx()) - - assert.Equal(t, int64(1), retrieveResultMsg.ID()) - assert.Equal(t, commonpb.MsgType_RetrieveResult, retrieveResultMsg.Type()) - assert.Equal(t, int64(3), retrieveResultMsg.SourceID()) - - bytes, err := retrieveResultMsg.Marshal(retrieveResultMsg) - assert.Nil(t, err) - - tsMsg, err := retrieveResultMsg.Unmarshal(bytes) - assert.Nil(t, err) - - retrieveResultMsg2, ok := tsMsg.(*RetrieveResultMsg) - assert.True(t, ok) - assert.Equal(t, int64(1), retrieveResultMsg2.ID()) - assert.Equal(t, commonpb.MsgType_RetrieveResult, retrieveResultMsg2.Type()) - assert.Equal(t, int64(3), retrieveResultMsg2.SourceID()) -} - -func TestRetrieveResultMsg_Unmarshal_IllegalParameter(t *testing.T) { - retrieveResultMsg := &RetrieveResultMsg{} - tsMsg, err := retrieveResultMsg.Unmarshal(10) - assert.NotNil(t, err) - assert.Nil(t, tsMsg) -} - func TestTimeTickMsg(t *testing.T) { timeTickMsg := &TimeTickMsg{ BaseMsg: generateBaseMsg(), @@ -838,67 +603,3 @@ func TestDataNodeTtMsg_Unmarshal_IllegalParameter(t *testing.T) { assert.NotNil(t, err) assert.Nil(t, tsMsg) } - -func TestSealedSegmentsChangeInfoMsg(t *testing.T) { - genSimpleSegmentInfo := func(segmentID UniqueID) *querypb.SegmentInfo { - return &querypb.SegmentInfo{ - SegmentID: segmentID, - } - } - - changeInfo := &querypb.SegmentChangeInfo{ - OnlineNodeID: int64(1), - OnlineSegments: []*querypb.SegmentInfo{ - genSimpleSegmentInfo(1), - genSimpleSegmentInfo(2), - genSimpleSegmentInfo(3), - }, - OfflineNodeID: int64(2), - OfflineSegments: []*querypb.SegmentInfo{ - genSimpleSegmentInfo(4), - genSimpleSegmentInfo(5), - genSimpleSegmentInfo(6), - }, - } - changeInfoMsg := &SealedSegmentsChangeInfoMsg{ - BaseMsg: generateBaseMsg(), - SealedSegmentsChangeInfo: querypb.SealedSegmentsChangeInfo{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_SealedSegmentsChangeInfo, - MsgID: 1, - Timestamp: 2, - SourceID: 3, - }, - Infos: []*querypb.SegmentChangeInfo{changeInfo}, - }, - } - - assert.NotNil(t, changeInfoMsg.TraceCtx()) - - ctx := context.Background() - changeInfoMsg.SetTraceCtx(ctx) - assert.Equal(t, ctx, changeInfoMsg.TraceCtx()) - - assert.Equal(t, int64(1), changeInfoMsg.ID()) - assert.Equal(t, commonpb.MsgType_SealedSegmentsChangeInfo, changeInfoMsg.Type()) - assert.Equal(t, int64(3), changeInfoMsg.SourceID()) - - bytes, err := changeInfoMsg.Marshal(changeInfoMsg) - assert.Nil(t, err) - - tsMsg, err := changeInfoMsg.Unmarshal(bytes) - assert.Nil(t, err) - - changeInfoMsg2, ok := tsMsg.(*SealedSegmentsChangeInfoMsg) - assert.True(t, ok) - assert.Equal(t, int64(1), changeInfoMsg2.ID()) - assert.Equal(t, commonpb.MsgType_SealedSegmentsChangeInfo, changeInfoMsg2.Type()) - assert.Equal(t, int64(3), changeInfoMsg2.SourceID()) -} - -func TestSealedSegmentsChangeInfoMsg_Unmarshal_IllegalParameter(t *testing.T) { - changeInfoMsg := &SealedSegmentsChangeInfoMsg{} - tsMsg, err := changeInfoMsg.Unmarshal(10) - assert.NotNil(t, err) - assert.Nil(t, tsMsg) -} diff --git a/internal/mq/msgstream/unmarshal.go b/internal/mq/msgstream/unmarshal.go index b132894afe..a58c7b9aa7 100644 --- a/internal/mq/msgstream/unmarshal.go +++ b/internal/mq/msgstream/unmarshal.go @@ -56,33 +56,23 @@ type ProtoUDFactory struct{} func (pudf *ProtoUDFactory) NewUnmarshalDispatcher() *ProtoUnmarshalDispatcher { insertMsg := InsertMsg{} deleteMsg := DeleteMsg{} - searchMsg := SearchMsg{} - searchResultMsg := SearchResultMsg{} - retrieveMsg := RetrieveMsg{} - retrieveResultMsg := RetrieveResultMsg{} timeTickMsg := TimeTickMsg{} createCollectionMsg := CreateCollectionMsg{} dropCollectionMsg := DropCollectionMsg{} createPartitionMsg := CreatePartitionMsg{} dropPartitionMsg := DropPartitionMsg{} dataNodeTtMsg := DataNodeTtMsg{} - sealedSegmentsChangeInfoMsg := SealedSegmentsChangeInfoMsg{} p := &ProtoUnmarshalDispatcher{} p.TempMap = make(map[commonpb.MsgType]UnmarshalFunc) p.TempMap[commonpb.MsgType_Insert] = insertMsg.Unmarshal p.TempMap[commonpb.MsgType_Delete] = deleteMsg.Unmarshal - p.TempMap[commonpb.MsgType_Search] = searchMsg.Unmarshal - p.TempMap[commonpb.MsgType_SearchResult] = searchResultMsg.Unmarshal - p.TempMap[commonpb.MsgType_Retrieve] = retrieveMsg.Unmarshal - p.TempMap[commonpb.MsgType_RetrieveResult] = retrieveResultMsg.Unmarshal p.TempMap[commonpb.MsgType_TimeTick] = timeTickMsg.Unmarshal p.TempMap[commonpb.MsgType_CreateCollection] = createCollectionMsg.Unmarshal p.TempMap[commonpb.MsgType_DropCollection] = dropCollectionMsg.Unmarshal p.TempMap[commonpb.MsgType_CreatePartition] = createPartitionMsg.Unmarshal p.TempMap[commonpb.MsgType_DropPartition] = dropPartitionMsg.Unmarshal p.TempMap[commonpb.MsgType_DataNodeTt] = dataNodeTtMsg.Unmarshal - p.TempMap[commonpb.MsgType_SealedSegmentsChangeInfo] = sealedSegmentsChangeInfoMsg.Unmarshal return p } diff --git a/internal/querynode/mock_test.go b/internal/querynode/mock_test.go index 64a1475810..9c2abf41f5 100644 --- a/internal/querynode/mock_test.go +++ b/internal/querynode/mock_test.go @@ -1480,13 +1480,13 @@ func genSimpleRetrievePlanExpr(schema *schemapb.CollectionSchema) ([]byte, error } func genSimpleRetrievePlan(collection *Collection) (*RetrievePlan, error) { - retrieveMsg, err := genRetrieveMsg(collection.schema) + timestamp := Timestamp(1000) + planBytes, err := genSimpleRetrievePlanExpr(collection.schema) if err != nil { return nil, err } - timestamp := retrieveMsg.RetrieveRequest.TravelTimestamp - plan, err2 := createRetrievePlanByExpr(collection, retrieveMsg.SerializedExprPlan, timestamp, 100) + plan, err2 := createRetrievePlanByExpr(collection, planBytes, timestamp, 100) return plan, err2 } @@ -1546,20 +1546,6 @@ func genRetrieveRequest(schema *schemapb.CollectionSchema) (*internalpb.Retrieve }, nil } -func genRetrieveMsg(schema *schemapb.CollectionSchema) (*msgstream.RetrieveMsg, error) { - req, err := genRetrieveRequest(schema) - if err != nil { - return nil, err - } - - msg := &msgstream.RetrieveMsg{ - BaseMsg: genMsgStreamBaseMsg(), - RetrieveRequest: *req, - } - msg.SetTimeRecorder() - return msg, nil -} - func genQueryResultChannel() Channel { const queryResultChannelPrefix = "query-node-unittest-query-result-channel-" return queryResultChannelPrefix + strconv.Itoa(rand.Int())