Merge retrieve and search code in query node (#6227)

Signed-off-by: xige-16 <xi.ge@zilliz.com>
This commit is contained in:
xige-16 2021-06-30 17:50:15 +08:00 committed by GitHub
parent f146d3825f
commit ff93d1611f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 267 additions and 345 deletions

View File

@ -32,6 +32,7 @@ type TsMsg interface {
BeginTs() Timestamp BeginTs() Timestamp
EndTs() Timestamp EndTs() Timestamp
Type() MsgType Type() MsgType
SourceID() int64
HashKeys() []uint32 HashKeys() []uint32
Marshal(TsMsg) (MarshalType, error) Marshal(TsMsg) (MarshalType, error)
Unmarshal(MarshalType) (TsMsg, error) Unmarshal(MarshalType) (TsMsg, error)
@ -97,6 +98,10 @@ func (it *InsertMsg) Type() MsgType {
return it.Base.MsgType return it.Base.MsgType
} }
// SourceID returns the SourceID stored in the message's Base,
// satisfying the SourceID method of the TsMsg interface.
func (it *InsertMsg) SourceID() int64 {
return it.Base.SourceID
}
func (it *InsertMsg) Marshal(input TsMsg) (MarshalType, error) { func (it *InsertMsg) Marshal(input TsMsg) (MarshalType, error) {
insertMsg := input.(*InsertMsg) insertMsg := input.(*InsertMsg)
insertRequest := &insertMsg.InsertRequest insertRequest := &insertMsg.InsertRequest
@ -157,6 +162,10 @@ func (fl *FlushCompletedMsg) Type() MsgType {
return fl.Base.MsgType return fl.Base.MsgType
} }
// SourceID returns the SourceID stored in the message's Base,
// satisfying the SourceID method of the TsMsg interface.
func (fl *FlushCompletedMsg) SourceID() int64 {
return fl.Base.SourceID
}
func (fl *FlushCompletedMsg) Marshal(input TsMsg) (MarshalType, error) { func (fl *FlushCompletedMsg) Marshal(input TsMsg) (MarshalType, error) {
flushCompletedMsgTask := input.(*FlushCompletedMsg) flushCompletedMsgTask := input.(*FlushCompletedMsg)
flushCompletedMsg := &flushCompletedMsgTask.SegmentFlushCompletedMsg flushCompletedMsg := &flushCompletedMsgTask.SegmentFlushCompletedMsg
@ -206,6 +215,10 @@ func (dt *DeleteMsg) Type() MsgType {
return dt.Base.MsgType return dt.Base.MsgType
} }
// SourceID returns the SourceID stored in the message's Base,
// satisfying the SourceID method of the TsMsg interface.
func (dt *DeleteMsg) SourceID() int64 {
return dt.Base.SourceID
}
func (dt *DeleteMsg) Marshal(input TsMsg) (MarshalType, error) { func (dt *DeleteMsg) Marshal(input TsMsg) (MarshalType, error) {
deleteMsg := input.(*DeleteMsg) deleteMsg := input.(*DeleteMsg)
deleteRequest := &deleteMsg.DeleteRequest deleteRequest := &deleteMsg.DeleteRequest
@ -267,6 +280,10 @@ func (st *SearchMsg) Type() MsgType {
return st.Base.MsgType return st.Base.MsgType
} }
// SourceID returns the SourceID stored in the message's Base,
// satisfying the SourceID method of the TsMsg interface.
func (st *SearchMsg) SourceID() int64 {
return st.Base.SourceID
}
func (st *SearchMsg) Marshal(input TsMsg) (MarshalType, error) { func (st *SearchMsg) Marshal(input TsMsg) (MarshalType, error) {
searchTask := input.(*SearchMsg) searchTask := input.(*SearchMsg)
searchRequest := &searchTask.SearchRequest searchRequest := &searchTask.SearchRequest
@ -316,6 +333,10 @@ func (srt *SearchResultMsg) Type() MsgType {
return srt.Base.MsgType return srt.Base.MsgType
} }
// SourceID returns the SourceID stored in the message's Base,
// satisfying the SourceID method of the TsMsg interface.
func (srt *SearchResultMsg) SourceID() int64 {
return srt.Base.SourceID
}
func (srt *SearchResultMsg) Marshal(input TsMsg) (MarshalType, error) { func (srt *SearchResultMsg) Marshal(input TsMsg) (MarshalType, error) {
searchResultTask := input.(*SearchResultMsg) searchResultTask := input.(*SearchResultMsg)
searchResultRequest := &searchResultTask.SearchResults searchResultRequest := &searchResultTask.SearchResults
@ -365,6 +386,10 @@ func (rm *RetrieveMsg) Type() MsgType {
return rm.Base.MsgType return rm.Base.MsgType
} }
// SourceID returns the SourceID stored in the message's Base,
// satisfying the SourceID method of the TsMsg interface.
func (rm *RetrieveMsg) SourceID() int64 {
return rm.Base.SourceID
}
func (rm *RetrieveMsg) Marshal(input TsMsg) (MarshalType, error) { func (rm *RetrieveMsg) Marshal(input TsMsg) (MarshalType, error) {
retrieveTask := input.(*RetrieveMsg) retrieveTask := input.(*RetrieveMsg)
retrieveRequest := &retrieveTask.RetrieveRequest retrieveRequest := &retrieveTask.RetrieveRequest
@ -414,6 +439,10 @@ func (rrm *RetrieveResultMsg) Type() MsgType {
return rrm.Base.MsgType return rrm.Base.MsgType
} }
// SourceID returns the SourceID stored in the message's Base,
// satisfying the SourceID method of the TsMsg interface.
func (rrm *RetrieveResultMsg) SourceID() int64 {
return rrm.Base.SourceID
}
func (rrm *RetrieveResultMsg) Marshal(input TsMsg) (MarshalType, error) { func (rrm *RetrieveResultMsg) Marshal(input TsMsg) (MarshalType, error) {
retrieveResultTask := input.(*RetrieveResultMsg) retrieveResultTask := input.(*RetrieveResultMsg)
retrieveResultRequest := &retrieveResultTask.RetrieveResults retrieveResultRequest := &retrieveResultTask.RetrieveResults
@ -463,6 +492,10 @@ func (tst *TimeTickMsg) Type() MsgType {
return tst.Base.MsgType return tst.Base.MsgType
} }
// SourceID returns the SourceID stored in the message's Base,
// satisfying the SourceID method of the TsMsg interface.
func (tst *TimeTickMsg) SourceID() int64 {
return tst.Base.SourceID
}
func (tst *TimeTickMsg) Marshal(input TsMsg) (MarshalType, error) { func (tst *TimeTickMsg) Marshal(input TsMsg) (MarshalType, error) {
timeTickTask := input.(*TimeTickMsg) timeTickTask := input.(*TimeTickMsg)
timeTick := &timeTickTask.TimeTickMsg timeTick := &timeTickTask.TimeTickMsg
@ -513,6 +546,10 @@ func (qs *QueryNodeStatsMsg) Type() MsgType {
return qs.Base.MsgType return qs.Base.MsgType
} }
// SourceID returns the SourceID stored in the message's Base,
// satisfying the SourceID method of the TsMsg interface.
func (qs *QueryNodeStatsMsg) SourceID() int64 {
return qs.Base.SourceID
}
func (qs *QueryNodeStatsMsg) Marshal(input TsMsg) (MarshalType, error) { func (qs *QueryNodeStatsMsg) Marshal(input TsMsg) (MarshalType, error) {
queryNodeSegStatsTask := input.(*QueryNodeStatsMsg) queryNodeSegStatsTask := input.(*QueryNodeStatsMsg)
queryNodeSegStats := &queryNodeSegStatsTask.QueryNodeStats queryNodeSegStats := &queryNodeSegStatsTask.QueryNodeStats
@ -560,6 +597,10 @@ func (ss *SegmentStatisticsMsg) Type() MsgType {
return ss.Base.MsgType return ss.Base.MsgType
} }
// SourceID returns the SourceID stored in the message's Base,
// satisfying the SourceID method of the TsMsg interface.
func (ss *SegmentStatisticsMsg) SourceID() int64 {
return ss.Base.SourceID
}
func (ss *SegmentStatisticsMsg) Marshal(input TsMsg) (MarshalType, error) { func (ss *SegmentStatisticsMsg) Marshal(input TsMsg) (MarshalType, error) {
segStatsTask := input.(*SegmentStatisticsMsg) segStatsTask := input.(*SegmentStatisticsMsg)
segStats := &segStatsTask.SegmentStatistics segStats := &segStatsTask.SegmentStatistics
@ -607,6 +648,10 @@ func (cc *CreateCollectionMsg) Type() MsgType {
return cc.Base.MsgType return cc.Base.MsgType
} }
// SourceID returns the SourceID stored in the message's Base,
// satisfying the SourceID method of the TsMsg interface.
func (cc *CreateCollectionMsg) SourceID() int64 {
return cc.Base.SourceID
}
func (cc *CreateCollectionMsg) Marshal(input TsMsg) (MarshalType, error) { func (cc *CreateCollectionMsg) Marshal(input TsMsg) (MarshalType, error) {
createCollectionMsg := input.(*CreateCollectionMsg) createCollectionMsg := input.(*CreateCollectionMsg)
createCollectionRequest := &createCollectionMsg.CreateCollectionRequest createCollectionRequest := &createCollectionMsg.CreateCollectionRequest
@ -656,6 +701,10 @@ func (dc *DropCollectionMsg) Type() MsgType {
return dc.Base.MsgType return dc.Base.MsgType
} }
// SourceID returns the SourceID stored in the message's Base,
// satisfying the SourceID method of the TsMsg interface.
func (dc *DropCollectionMsg) SourceID() int64 {
return dc.Base.SourceID
}
func (dc *DropCollectionMsg) Marshal(input TsMsg) (MarshalType, error) { func (dc *DropCollectionMsg) Marshal(input TsMsg) (MarshalType, error) {
dropCollectionMsg := input.(*DropCollectionMsg) dropCollectionMsg := input.(*DropCollectionMsg)
dropCollectionRequest := &dropCollectionMsg.DropCollectionRequest dropCollectionRequest := &dropCollectionMsg.DropCollectionRequest
@ -705,6 +754,10 @@ func (cp *CreatePartitionMsg) Type() MsgType {
return cp.Base.MsgType return cp.Base.MsgType
} }
// SourceID returns the SourceID stored in the message's Base,
// satisfying the SourceID method of the TsMsg interface.
func (cp *CreatePartitionMsg) SourceID() int64 {
return cp.Base.SourceID
}
func (cp *CreatePartitionMsg) Marshal(input TsMsg) (MarshalType, error) { func (cp *CreatePartitionMsg) Marshal(input TsMsg) (MarshalType, error) {
createPartitionMsg := input.(*CreatePartitionMsg) createPartitionMsg := input.(*CreatePartitionMsg)
createPartitionRequest := &createPartitionMsg.CreatePartitionRequest createPartitionRequest := &createPartitionMsg.CreatePartitionRequest
@ -754,6 +807,10 @@ func (dp *DropPartitionMsg) Type() MsgType {
return dp.Base.MsgType return dp.Base.MsgType
} }
// SourceID returns the SourceID stored in the message's Base,
// satisfying the SourceID method of the TsMsg interface.
func (dp *DropPartitionMsg) SourceID() int64 {
return dp.Base.SourceID
}
func (dp *DropPartitionMsg) Marshal(input TsMsg) (MarshalType, error) { func (dp *DropPartitionMsg) Marshal(input TsMsg) (MarshalType, error) {
dropPartitionMsg := input.(*DropPartitionMsg) dropPartitionMsg := input.(*DropPartitionMsg)
dropPartitionRequest := &dropPartitionMsg.DropPartitionRequest dropPartitionRequest := &dropPartitionMsg.DropPartitionRequest
@ -803,6 +860,10 @@ func (lim *LoadIndexMsg) Type() MsgType {
return lim.Base.MsgType return lim.Base.MsgType
} }
// SourceID returns the SourceID stored in the message's Base,
// satisfying the SourceID method of the TsMsg interface.
func (lim *LoadIndexMsg) SourceID() int64 {
return lim.Base.SourceID
}
func (lim *LoadIndexMsg) Marshal(input TsMsg) (MarshalType, error) { func (lim *LoadIndexMsg) Marshal(input TsMsg) (MarshalType, error) {
loadIndexMsg := input.(*LoadIndexMsg) loadIndexMsg := input.(*LoadIndexMsg)
loadIndexRequest := &loadIndexMsg.LoadIndex loadIndexRequest := &loadIndexMsg.LoadIndex
@ -850,6 +911,10 @@ func (sim *SegmentInfoMsg) Type() MsgType {
return sim.Base.MsgType return sim.Base.MsgType
} }
// SourceID returns the SourceID stored in the message's Base,
// satisfying the SourceID method of the TsMsg interface.
func (sim *SegmentInfoMsg) SourceID() int64 {
return sim.Base.SourceID
}
func (sim *SegmentInfoMsg) Marshal(input TsMsg) (MarshalType, error) { func (sim *SegmentInfoMsg) Marshal(input TsMsg) (MarshalType, error) {
segInfoMsg := input.(*SegmentInfoMsg) segInfoMsg := input.(*SegmentInfoMsg)
mb, err := proto.Marshal(&segInfoMsg.SegmentMsg) mb, err := proto.Marshal(&segInfoMsg.SegmentMsg)
@ -896,6 +961,10 @@ func (l *LoadBalanceSegmentsMsg) Type() MsgType {
return l.Base.MsgType return l.Base.MsgType
} }
// SourceID returns the SourceID stored in the message's Base,
// satisfying the SourceID method of the TsMsg interface.
func (l *LoadBalanceSegmentsMsg) SourceID() int64 {
return l.Base.SourceID
}
func (l *LoadBalanceSegmentsMsg) Marshal(input TsMsg) (MarshalType, error) { func (l *LoadBalanceSegmentsMsg) Marshal(input TsMsg) (MarshalType, error) {
load := input.(*LoadBalanceSegmentsMsg) load := input.(*LoadBalanceSegmentsMsg)
loadReq := &load.LoadBalanceSegmentsRequest loadReq := &load.LoadBalanceSegmentsRequest
@ -944,6 +1013,10 @@ func (m *DataNodeTtMsg) Type() MsgType {
return m.Base.MsgType return m.Base.MsgType
} }
// SourceID returns the SourceID stored in the message's Base,
// satisfying the SourceID method of the TsMsg interface.
func (m *DataNodeTtMsg) SourceID() int64 {
return m.Base.SourceID
}
func (m *DataNodeTtMsg) Marshal(input TsMsg) (MarshalType, error) { func (m *DataNodeTtMsg) Marshal(input TsMsg) (MarshalType, error) {
msg := input.(*DataNodeTtMsg) msg := input.(*DataNodeTtMsg)
t, err := proto.Marshal(&msg.DataNodeTtMsg) t, err := proto.Marshal(&msg.DataNodeTtMsg)

View File

@ -14,7 +14,6 @@ package querynode
import ( import (
"context" "context"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"math" "math"
"reflect" "reflect"
@ -45,8 +44,7 @@ type queryCollection struct {
streaming *streaming streaming *streaming
unsolvedMsgMu sync.Mutex // guards unsolvedMsg unsolvedMsgMu sync.Mutex // guards unsolvedMsg
unsolvedMsg []*msgstream.SearchMsg unsolvedMsg []msgstream.TsMsg
unsolvedRetrieveMsg []*msgstream.RetrieveMsg
tSafeWatchers map[Channel]*tSafeWatcher tSafeWatchers map[Channel]*tSafeWatcher
watcherSelectCase []reflect.SelectCase watcherSelectCase []reflect.SelectCase
@ -67,7 +65,7 @@ func newQueryCollection(releaseCtx context.Context,
streaming *streaming, streaming *streaming,
factory msgstream.Factory) *queryCollection { factory msgstream.Factory) *queryCollection {
unsolvedMsg := make([]*msgstream.SearchMsg, 0) unsolvedMsg := make([]msgstream.TsMsg, 0)
queryStream, _ := factory.NewQueryMsgStream(releaseCtx) queryStream, _ := factory.NewQueryMsgStream(releaseCtx)
queryResultStream, _ := factory.NewQueryMsgStream(releaseCtx) queryResultStream, _ := factory.NewQueryMsgStream(releaseCtx)
@ -96,8 +94,7 @@ func (q *queryCollection) start() {
go q.queryMsgStream.Start() go q.queryMsgStream.Start()
go q.queryResultMsgStream.Start() go q.queryResultMsgStream.Start()
go q.consumeQuery() go q.consumeQuery()
go q.doUnsolvedMsgSearch() go q.doUnsolvedQueryMsg()
go q.doUnsolvedMsgRetrieve()
} }
func (q *queryCollection) close() { func (q *queryCollection) close() {
@ -132,19 +129,13 @@ func (q *queryCollection) register() {
} }
} }
func (q *queryCollection) addToUnsolvedMsg(msg *msgstream.SearchMsg) { func (q *queryCollection) addToUnsolvedMsg(msg msgstream.TsMsg) {
q.unsolvedMsgMu.Lock() q.unsolvedMsgMu.Lock()
defer q.unsolvedMsgMu.Unlock() defer q.unsolvedMsgMu.Unlock()
q.unsolvedMsg = append(q.unsolvedMsg, msg) q.unsolvedMsg = append(q.unsolvedMsg, msg)
} }
func (q *queryCollection) addToUnsolvedRetrieveMsg(msg *msgstream.RetrieveMsg) { func (q *queryCollection) popAllUnsolvedMsg() []msgstream.TsMsg {
q.unsolvedMsgMu.Lock()
defer q.unsolvedMsgMu.Unlock()
q.unsolvedRetrieveMsg = append(q.unsolvedRetrieveMsg, msg)
}
func (q *queryCollection) popAllUnsolvedMsg() []*msgstream.SearchMsg {
q.unsolvedMsgMu.Lock() q.unsolvedMsgMu.Lock()
defer q.unsolvedMsgMu.Unlock() defer q.unsolvedMsgMu.Unlock()
tmp := q.unsolvedMsg tmp := q.unsolvedMsg
@ -152,14 +143,6 @@ func (q *queryCollection) popAllUnsolvedMsg() []*msgstream.SearchMsg {
return tmp return tmp
} }
func (q *queryCollection) popAllUnsolvedRetrieveMsg() []*msgstream.RetrieveMsg {
q.unsolvedMsgMu.Lock()
defer q.unsolvedMsgMu.Unlock()
tmp := q.unsolvedRetrieveMsg
q.unsolvedRetrieveMsg = q.unsolvedRetrieveMsg[:0]
return tmp
}
func (q *queryCollection) waitNewTSafe() Timestamp { func (q *queryCollection) waitNewTSafe() Timestamp {
// block until any vChannel updating tSafe // block until any vChannel updating tSafe
_, _, recvOK := reflect.Select(q.watcherSelectCase) _, _, recvOK := reflect.Select(q.watcherSelectCase)
@ -201,17 +184,6 @@ func (q *queryCollection) setServiceableTime(t Timestamp) {
} }
} }
func (q *queryCollection) emptySearch(searchMsg *msgstream.SearchMsg) {
sp, ctx := trace.StartSpanFromContext(searchMsg.TraceCtx())
defer sp.Finish()
searchMsg.SetTraceCtx(ctx)
err := q.search(searchMsg)
if err != nil {
log.Error(err.Error())
q.publishFailedSearchResult(searchMsg, err.Error())
}
}
func (q *queryCollection) consumeQuery() { func (q *queryCollection) consumeQuery() {
for { for {
select { select {
@ -233,11 +205,11 @@ func (q *queryCollection) consumeQuery() {
for _, msg := range msgPack.Msgs { for _, msg := range msgPack.Msgs {
switch sm := msg.(type) { switch sm := msg.(type) {
case *msgstream.SearchMsg: case *msgstream.SearchMsg:
q.receiveSearch(sm) q.receiveQueryMsg(sm)
case *msgstream.LoadBalanceSegmentsMsg: case *msgstream.LoadBalanceSegmentsMsg:
q.loadBalance(sm) q.loadBalance(sm)
case *msgstream.RetrieveMsg: case *msgstream.RetrieveMsg:
q.receiveRetrieve(sm) q.receiveQueryMsg(sm)
default: default:
log.Warn("unsupported msg type in search channel", zap.Any("msg", sm)) log.Warn("unsupported msg type in search channel", zap.Any("msg", sm))
} }
@ -277,123 +249,86 @@ func (q *queryCollection) loadBalance(msg *msgstream.LoadBalanceSegmentsMsg) {
// zap.Int("num of segment", len(msg.Infos))) // zap.Int("num of segment", len(msg.Infos)))
} }
func (q *queryCollection) receiveRetrieve(msg *msgstream.RetrieveMsg) { func (q *queryCollection) receiveQueryMsg(msg msgstream.TsMsg) {
if msg.CollectionID != q.collectionID { msgType := msg.Type()
log.Debug("not target collection retrieve request", var collectionID UniqueID
zap.Any("collectionID", msg.CollectionID), var msgTypeStr string
zap.Int64("msgID", msg.ID()),
)
return
}
switch msgType {
case commonpb.MsgType_Retrieve:
collectionID = msg.(*msgstream.RetrieveMsg).CollectionID
msgTypeStr = "retrieve"
log.Debug("consume retrieve message", log.Debug("consume retrieve message",
zap.Any("collectionID", msg.CollectionID), zap.Any("collectionID", collectionID),
zap.Int64("msgID", msg.ID()), zap.Int64("msgID", msg.ID()),
) )
sp, ctx := trace.StartSpanFromContext(msg.TraceCtx()) case commonpb.MsgType_Search:
msg.SetTraceCtx(ctx) collectionID = msg.(*msgstream.SearchMsg).CollectionID
msgTypeStr = "search"
// check if collection has been released
collection, err := q.historical.replica.getCollectionByID(msg.CollectionID)
if err != nil {
log.Error(err.Error())
q.publishFailedRetrieveResult(msg, err.Error())
return
}
if msg.BeginTs() >= collection.getReleaseTime() {
err := errors.New("retrieve failed, collection has been released, msgID = " +
fmt.Sprintln(msg.ID()) +
", collectionID = " +
fmt.Sprintln(msg.CollectionID))
log.Error(err.Error())
q.publishFailedRetrieveResult(msg, err.Error())
return
}
serviceTime := q.getServiceableTime()
if msg.BeginTs() > serviceTime {
bt, _ := tsoutil.ParseTS(msg.BeginTs())
st, _ := tsoutil.ParseTS(serviceTime)
log.Debug("query node::receiveRetrieveMsg: add to unsolvedMsg",
zap.Any("collectionID", q.collectionID),
zap.Any("sm.BeginTs", bt),
zap.Any("serviceTime", st),
zap.Any("delta seconds", (msg.BeginTs()-serviceTime)/(1000*1000*1000)),
zap.Any("msgID", msg.ID()),
)
q.addToUnsolvedRetrieveMsg(msg)
sp.LogFields(
oplog.String("send to unsolved buffer", "send to unsolved buffer"),
oplog.Object("begin ts", bt),
oplog.Object("serviceTime", st),
oplog.Float64("delta seconds", float64(msg.BeginTs()-serviceTime)/(1000.0*1000.0*1000.0)),
)
sp.Finish()
return
}
log.Debug("doing retrieve in receiveRetrieveMsg...",
zap.Int64("collectionID", msg.CollectionID),
zap.Int64("msgID", msg.ID()),
)
err = q.retrieve(msg)
if err != nil {
log.Error(err.Error())
log.Debug("do retrieve failed in receiveRetrieveMsg, prepare to publish failed retrieve result",
zap.Int64("collectionID", msg.CollectionID),
zap.Int64("msgID", msg.ID()),
)
q.publishFailedRetrieveResult(msg, err.Error())
}
log.Debug("do retrieve done in receiveRetrieve",
zap.Int64("collectionID", msg.CollectionID),
zap.Int64("msgID", msg.ID()),
)
sp.Finish()
}
func (q *queryCollection) receiveSearch(msg *msgstream.SearchMsg) {
if msg.CollectionID != q.collectionID {
log.Debug("not target collection search request",
zap.Any("collectionID", msg.CollectionID),
zap.Int64("msgID", msg.ID()),
)
return
}
log.Debug("consume search message", log.Debug("consume search message",
zap.Any("collectionID", msg.CollectionID), zap.Any("collectionID", collectionID),
zap.Int64("msgID", msg.ID()), zap.Int64("msgID", msg.ID()),
) )
default:
err := fmt.Errorf("receive invalid msgType = %d", msgType)
log.Error(err.Error())
return
}
if collectionID != q.collectionID {
log.Error("not target collection query request",
zap.Any("collectionID", q.collectionID),
zap.Int64("target collectionID", collectionID),
zap.Int64("msgID", msg.ID()),
)
return
}
sp, ctx := trace.StartSpanFromContext(msg.TraceCtx()) sp, ctx := trace.StartSpanFromContext(msg.TraceCtx())
msg.SetTraceCtx(ctx) msg.SetTraceCtx(ctx)
// check if collection has been released // check if collection has been released
collection, err := q.historical.replica.getCollectionByID(msg.CollectionID) collection, err := q.historical.replica.getCollectionByID(collectionID)
if err != nil { if err != nil {
log.Error(err.Error()) log.Error(err.Error())
q.publishFailedSearchResult(msg, err.Error()) err = q.publishFailedQueryResult(msg, err.Error())
if err != nil {
log.Error(err.Error())
} else {
log.Debug("do query failed in receiveQueryMsg, publish failed query result",
zap.Int64("collectionID", collectionID),
zap.Int64("msgID", msg.ID()),
zap.String("msgType", msgTypeStr),
)
}
return return
} }
if msg.BeginTs() >= collection.getReleaseTime() { if msg.BeginTs() >= collection.getReleaseTime() {
err := errors.New("search failed, collection has been released, msgID = " + err = fmt.Errorf("retrieve failed, collection has been released, msgID = %d, collectionID = %d", msg.ID(), collectionID)
fmt.Sprintln(msg.ID()) +
", collectionID = " +
fmt.Sprintln(msg.CollectionID))
log.Error(err.Error()) log.Error(err.Error())
q.publishFailedSearchResult(msg, err.Error()) err = q.publishFailedQueryResult(msg, err.Error())
if err != nil {
log.Error(err.Error())
} else {
log.Debug("do query failed in receiveQueryMsg, publish failed query result",
zap.Int64("collectionID", collectionID),
zap.Int64("msgID", msg.ID()),
zap.String("msgType", msgTypeStr),
)
}
return return
} }
serviceTime := q.getServiceableTime() serviceTime := q.getServiceableTime()
if msg.BeginTs() > serviceTime {
bt, _ := tsoutil.ParseTS(msg.BeginTs()) bt, _ := tsoutil.ParseTS(msg.BeginTs())
st, _ := tsoutil.ParseTS(serviceTime) st, _ := tsoutil.ParseTS(serviceTime)
if msg.BeginTs() > serviceTime { log.Debug("query node::receiveQueryMsg: add to unsolvedMsg",
log.Debug("query node::receiveSearchMsg: add to unsolvedMsg",
zap.Any("collectionID", q.collectionID), zap.Any("collectionID", q.collectionID),
zap.Any("sm.BeginTs", bt), zap.Any("sm.BeginTs", bt),
zap.Any("serviceTime", st), zap.Any("serviceTime", st),
zap.Any("delta seconds", (msg.BeginTs()-serviceTime)/(1000*1000*1000)), zap.Any("delta seconds", (msg.BeginTs()-serviceTime)/(1000*1000*1000)),
zap.Any("msgID", msg.ID()), zap.Any("msgID", msg.ID()),
zap.String("msgType", msgTypeStr),
) )
q.addToUnsolvedMsg(msg) q.addToUnsolvedMsg(msg)
sp.LogFields( sp.LogFields(
@ -405,36 +340,49 @@ func (q *queryCollection) receiveSearch(msg *msgstream.SearchMsg) {
sp.Finish() sp.Finish()
return return
} }
log.Debug("doing search in receiveSearchMsg...", log.Debug("doing query in receiveQueryMsg...",
zap.Int64("collectionID", msg.CollectionID), zap.Int64("collectionID", collectionID),
zap.Int64("msgID", msg.ID()), zap.Int64("msgID", msg.ID()),
zap.Any("serviceTime_l", serviceTime), zap.String("msgType", msgTypeStr),
zap.Any("searchTime_l", msg.BeginTs()),
zap.Any("serviceTime_p", st),
zap.Any("searchTime_p", bt),
) )
switch msgType {
case commonpb.MsgType_Retrieve:
err = q.retrieve(msg)
case commonpb.MsgType_Search:
err = q.search(msg) err = q.search(msg)
default:
err := fmt.Errorf("receive invalid msgType = %d", msgType)
log.Error(err.Error())
return
}
if err != nil { if err != nil {
log.Error(err.Error()) log.Error(err.Error())
log.Debug("do search failed in receiveSearchMsg, prepare to publish failed search result", err = q.publishFailedQueryResult(msg, err.Error())
zap.Int64("collectionID", msg.CollectionID), if err != nil {
log.Error(err.Error())
} else {
log.Debug("do query failed in receiveQueryMsg, publish failed query result",
zap.Int64("collectionID", collectionID),
zap.Int64("msgID", msg.ID()), zap.Int64("msgID", msg.ID()),
zap.String("msgType", msgTypeStr),
) )
q.publishFailedSearchResult(msg, err.Error())
} }
log.Debug("do search done in receiveSearch", }
zap.Int64("collectionID", msg.CollectionID), log.Debug("do query done in receiveQueryMsg",
zap.Int64("collectionID", collectionID),
zap.Int64("msgID", msg.ID()), zap.Int64("msgID", msg.ID()),
zap.String("msgType", msgTypeStr),
) )
sp.Finish() sp.Finish()
} }
func (q *queryCollection) doUnsolvedMsgSearch() { func (q *queryCollection) doUnsolvedQueryMsg() {
log.Debug("starting doUnsolvedMsgSearch...", zap.Any("collectionID", q.collectionID)) log.Debug("starting doUnsolvedMsg...", zap.Any("collectionID", q.collectionID))
for { for {
select { select {
case <-q.releaseCtx.Done(): case <-q.releaseCtx.Done():
log.Debug("stop searchCollection's doUnsolvedMsgSearch", zap.Int64("collectionID", q.collectionID)) log.Debug("stop Collection's doUnsolvedMsg", zap.Int64("collectionID", q.collectionID))
return return
default: default:
//time.Sleep(10 * time.Millisecond) //time.Sleep(10 * time.Millisecond)
@ -445,64 +393,80 @@ func (q *queryCollection) doUnsolvedMsgSearch() {
zap.Any("tSafe", st)) zap.Any("tSafe", st))
q.setServiceableTime(serviceTime) q.setServiceableTime(serviceTime)
//log.Debug("query node::doUnsolvedMsgSearch: setServiceableTime", //log.Debug("query node::doUnsolvedMsg: setServiceableTime",
// zap.Any("serviceTime", st), // zap.Any("serviceTime", st),
//) //)
searchMsg := make([]*msgstream.SearchMsg, 0) unSolvedMsg := make([]msgstream.TsMsg, 0)
tempMsg := q.popAllUnsolvedMsg() tempMsg := q.popAllUnsolvedMsg()
for _, sm := range tempMsg { for _, m := range tempMsg {
bt, _ := tsoutil.ParseTS(sm.EndTs()) bt, _ := tsoutil.ParseTS(m.EndTs())
st, _ = tsoutil.ParseTS(serviceTime) st, _ = tsoutil.ParseTS(serviceTime)
log.Debug("get search message from unsolvedMsg", log.Debug("get query message from unsolvedMsg",
zap.Int64("collectionID", sm.CollectionID), zap.Int64("collectionID", q.collectionID),
zap.Int64("msgID", sm.ID()), zap.Int64("msgID", m.ID()),
zap.Any("reqTime_p", bt), zap.Any("reqTime_p", bt),
zap.Any("serviceTime_p", st), zap.Any("serviceTime_p", st),
zap.Any("reqTime_l", sm.EndTs()), zap.Any("reqTime_l", m.EndTs()),
zap.Any("serviceTime_l", serviceTime), zap.Any("serviceTime_l", serviceTime),
) )
if sm.EndTs() <= serviceTime { if m.EndTs() <= serviceTime {
searchMsg = append(searchMsg, sm) unSolvedMsg = append(unSolvedMsg, m)
continue continue
} }
log.Debug("query node::doUnsolvedMsgSearch: add to unsolvedMsg", log.Debug("query node::doUnsolvedMsg: add to unsolvedMsg",
zap.Any("collectionID", q.collectionID), zap.Any("collectionID", q.collectionID),
zap.Any("sm.BeginTs", bt), zap.Any("sm.BeginTs", bt),
zap.Any("serviceTime", st), zap.Any("serviceTime", st),
zap.Any("delta seconds", (sm.BeginTs()-serviceTime)/(1000*1000*1000)), zap.Any("delta seconds", (m.BeginTs()-serviceTime)/(1000*1000*1000)),
zap.Any("msgID", sm.ID()), zap.Any("msgID", m.ID()),
) )
q.addToUnsolvedMsg(sm) q.addToUnsolvedMsg(m)
} }
if len(searchMsg) <= 0 { if len(unSolvedMsg) <= 0 {
continue continue
} }
for _, sm := range searchMsg { for _, m := range unSolvedMsg {
sp, ctx := trace.StartSpanFromContext(sm.TraceCtx()) msgType := m.Type()
sm.SetTraceCtx(ctx) var err error
log.Debug("doing search in doUnsolvedMsgSearch...", sp, ctx := trace.StartSpanFromContext(m.TraceCtx())
zap.Int64("collectionID", sm.CollectionID), m.SetTraceCtx(ctx)
zap.Int64("msgID", sm.ID()), log.Debug("doing search in doUnsolvedMsg...",
zap.Int64("collectionID", q.collectionID),
zap.Int64("msgID", m.ID()),
) )
err := q.search(sm) switch msgType {
case commonpb.MsgType_Retrieve:
err = q.retrieve(m)
case commonpb.MsgType_Search:
err = q.search(m)
default:
err := fmt.Errorf("receive invalid msgType = %d", msgType)
log.Error(err.Error())
return
}
if err != nil { if err != nil {
log.Error(err.Error()) log.Error(err.Error())
log.Debug("do search failed in doUnsolvedMsgSearch, prepare to publish failed search result", err = q.publishFailedQueryResult(m, err.Error())
zap.Int64("collectionID", sm.CollectionID), if err != nil {
zap.Int64("msgID", sm.ID()), log.Error(err.Error())
} else {
log.Debug("do query failed in doUnsolvedMsg, publish failed query result",
zap.Int64("collectionID", q.collectionID),
zap.Int64("msgID", m.ID()),
) )
q.publishFailedSearchResult(sm, err.Error()) }
} }
sp.Finish() sp.Finish()
log.Debug("do search done in doUnsolvedMsgSearch", log.Debug("do query done in doUnsolvedMsg",
zap.Int64("collectionID", sm.CollectionID), zap.Int64("collectionID", q.collectionID),
zap.Int64("msgID", sm.ID()), zap.Int64("msgID", m.ID()),
) )
} }
log.Debug("doUnsolvedMsgSearch, do search done", zap.Int("num of searchMsg", len(searchMsg))) log.Debug("doUnsolvedMsg: do query done", zap.Int("num of query msg", len(unSolvedMsg)))
} }
} }
} }
@ -731,7 +695,8 @@ func translateHits(schema *typeutil.SchemaHelper, fieldIDs []int64, rawHits [][]
// TODO:: cache map[dsl]plan // TODO:: cache map[dsl]plan
// TODO: reBatched search requests // TODO: reBatched search requests
func (q *queryCollection) search(searchMsg *msgstream.SearchMsg) error { func (q *queryCollection) search(msg msgstream.TsMsg) error {
searchMsg := msg.(*msgstream.SearchMsg)
sp, ctx := trace.StartSpanFromContext(searchMsg.TraceCtx()) sp, ctx := trace.StartSpanFromContext(searchMsg.TraceCtx())
defer sp.Finish() defer sp.Finish()
searchMsg.SetTraceCtx(ctx) searchMsg.SetTraceCtx(ctx)
@ -873,7 +838,7 @@ func (q *queryCollection) search(searchMsg *msgstream.SearchMsg) error {
zap.Any("vChannels", collection.getVChannels()), zap.Any("vChannels", collection.getVChannels()),
zap.Any("sealedSegmentSearched", sealedSegmentSearched), zap.Any("sealedSegmentSearched", sealedSegmentSearched),
) )
err = q.publishSearchResult(searchResultMsg, searchMsg.CollectionID) err = q.publishQueryResult(searchResultMsg, searchMsg.CollectionID)
if err != nil { if err != nil {
return err return err
} }
@ -993,7 +958,7 @@ func (q *queryCollection) search(searchMsg *msgstream.SearchMsg) error {
// fmt.Println(testHits.IDs) // fmt.Println(testHits.IDs)
// fmt.Println(testHits.Scores) // fmt.Println(testHits.Scores)
//} //}
err = q.publishSearchResult(searchResultMsg, searchMsg.CollectionID) err = q.publishQueryResult(searchResultMsg, searchMsg.CollectionID)
if err != nil { if err != nil {
return err return err
} }
@ -1008,147 +973,14 @@ func (q *queryCollection) search(searchMsg *msgstream.SearchMsg) error {
return nil return nil
} }
func (q *queryCollection) publishSearchResult(msg msgstream.TsMsg, collectionID UniqueID) error { func (q *queryCollection) retrieve(msg msgstream.TsMsg) error {
log.Debug("publishing search result...",
zap.Int64("collectionID", collectionID),
zap.Int64("msgID", msg.ID()),
)
span, ctx := trace.StartSpanFromContext(msg.TraceCtx())
defer span.Finish()
msg.SetTraceCtx(ctx)
msgPack := msgstream.MsgPack{}
msgPack.Msgs = append(msgPack.Msgs, msg)
err := q.queryResultMsgStream.Produce(&msgPack)
if err != nil {
log.Error("publishing search result failed, err = "+err.Error(),
zap.Int64("collectionID", collectionID),
zap.Int64("msgID", msg.ID()),
)
} else {
log.Debug("publish search result done",
zap.Int64("collectionID", collectionID),
zap.Int64("msgID", msg.ID()),
)
}
return err
}
func (q *queryCollection) publishFailedSearchResult(searchMsg *msgstream.SearchMsg, errMsg string) {
span, ctx := trace.StartSpanFromContext(searchMsg.TraceCtx())
defer span.Finish()
searchMsg.SetTraceCtx(ctx)
//log.Debug("Public fail SearchResult!")
msgPack := msgstream.MsgPack{}
resultChannelInt := 0
searchResultMsg := &msgstream.SearchResultMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(resultChannelInt)}},
SearchResults: internalpb.SearchResults{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_SearchResult,
MsgID: searchMsg.Base.MsgID,
Timestamp: searchMsg.Base.Timestamp,
SourceID: searchMsg.Base.SourceID,
},
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: errMsg},
ResultChannelID: searchMsg.ResultChannelID,
},
}
msgPack.Msgs = append(msgPack.Msgs, searchResultMsg)
err := q.queryResultMsgStream.Produce(&msgPack)
if err != nil {
log.Error("publish FailedSearchResult failed" + err.Error())
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
func (q *queryCollection) doUnsolvedMsgRetrieve() {
log.Debug("starting doUnsolvedMsgRetrieve...", zap.Any("collectionID", q.collectionID))
for {
select {
case <-q.releaseCtx.Done():
log.Debug("stop retrieveCollection's doUnsolvedMsgRertieve", zap.Int64("collectionID", q.collectionID))
return
default:
//time.Sleep(10 * time.Millisecond)
serviceTime := q.waitNewTSafe()
st, _ := tsoutil.ParseTS(serviceTime)
log.Debug("get tSafe from flow graph",
zap.Int64("collectionID", q.collectionID),
zap.Any("tSafe", st))
q.setServiceableTime(serviceTime)
//log.Debug("query node::doUnsolvedMsgSearch: setServiceableTime",
// zap.Any("serviceTime", st),
//)
retrieveMsg := make([]*msgstream.RetrieveMsg, 0)
tempMsg := q.popAllUnsolvedRetrieveMsg()
for _, rm := range tempMsg {
bt, _ := tsoutil.ParseTS(rm.EndTs())
st, _ = tsoutil.ParseTS(serviceTime)
log.Debug("get retrieve message from unsolvedMsg",
zap.Int64("collectionID", rm.CollectionID),
zap.Int64("msgID", rm.ID()),
zap.Any("reqTime_p", bt),
zap.Any("serviceTime_p", st),
zap.Any("reqTime_l", rm.EndTs()),
zap.Any("serviceTime_l", serviceTime),
)
if rm.EndTs() <= serviceTime {
retrieveMsg = append(retrieveMsg, rm)
continue
}
log.Debug("query node::doUnsolvedMsgRetrieve: add to unsolvedMsg",
zap.Any("collectionID", q.collectionID),
zap.Any("sm.BeginTs", bt),
zap.Any("serviceTime", st),
zap.Any("delta seconds", (rm.BeginTs()-serviceTime)/(1000*1000*1000)),
zap.Any("msgID", rm.ID()),
)
q.addToUnsolvedRetrieveMsg(rm)
}
if len(retrieveMsg) <= 0 {
continue
}
for _, rm := range retrieveMsg {
sp, ctx := trace.StartSpanFromContext(rm.TraceCtx())
rm.SetTraceCtx(ctx)
log.Debug("doing search in doUnsolvedMsgRetrieve...",
zap.Int64("collectionID", rm.CollectionID),
zap.Int64("msgID", rm.ID()),
)
err := q.retrieve(rm)
if err != nil {
log.Error(err.Error())
log.Debug("do retrieve failed in doUnsolvedMsgSearch, prepare to publish failed retrieve result",
zap.Int64("collectionID", rm.CollectionID),
zap.Int64("msgID", rm.ID()),
)
q.publishFailedRetrieveResult(rm, err.Error())
}
sp.Finish()
log.Debug("do retrieve done in doUnsolvedMsgSearch",
zap.Int64("collectionID", rm.CollectionID),
zap.Int64("msgID", rm.ID()),
)
}
log.Debug("doUnsolvedMsgRetrieve, do retrieve done", zap.Int("num of retrieveMsg", len(retrieveMsg)))
}
}
}
func (q *queryCollection) retrieve(retrieveMsg *msgstream.RetrieveMsg) error {
// TODO(yukun) // TODO(yukun)
// step 1: get retrieve object and defer destruction // step 1: get retrieve object and defer destruction
// step 2: for each segment, call retrieve to get ids proto buffer // step 2: for each segment, call retrieve to get ids proto buffer
// step 3: merge all proto in go // step 3: merge all proto in go
// step 4: publish results // step 4: publish results
// retrieveProtoBlob, err := proto.Marshal(&retrieveMsg.RetrieveRequest) // retrieveProtoBlob, err := proto.Marshal(&retrieveMsg.RetrieveRequest)
retrieveMsg := msg.(*msgstream.RetrieveMsg)
sp, ctx := trace.StartSpanFromContext(retrieveMsg.TraceCtx()) sp, ctx := trace.StartSpanFromContext(retrieveMsg.TraceCtx())
defer sp.Finish() defer sp.Finish()
retrieveMsg.SetTraceCtx(ctx) retrieveMsg.SetTraceCtx(ctx)
@ -1237,8 +1069,6 @@ func (q *queryCollection) retrieve(retrieveMsg *msgstream.RetrieveMsg) error {
} }
} }
log.Debug("1111", zap.Any("len of mergeList", len(mergeList)))
result, err := mergeRetrieveResults(mergeList) result, err := mergeRetrieveResults(mergeList)
if err != nil { if err != nil {
return err return err
@ -1263,15 +1093,16 @@ func (q *queryCollection) retrieve(retrieveMsg *msgstream.RetrieveMsg) error {
GlobalSealedSegmentIDs: sealedSegmentRetrieved, GlobalSealedSegmentIDs: sealedSegmentRetrieved,
}, },
} }
log.Debug("QueryNode RetrieveResultMsg",
err3 := q.publishQueryResult(retrieveResultMsg, retrieveMsg.CollectionID)
if err3 != nil {
return err3
}
log.Debug("QueryNode publish RetrieveResultMsg",
zap.Any("vChannels", collection.getVChannels()), zap.Any("vChannels", collection.getVChannels()),
zap.Any("collectionID", collection.ID()), zap.Any("collectionID", collection.ID()),
zap.Any("sealedSegmentRetrieved", sealedSegmentRetrieved), zap.Any("sealedSegmentRetrieved", sealedSegmentRetrieved),
) )
err3 := q.publishRetrieveResult(retrieveResultMsg, retrieveMsg.CollectionID)
if err3 != nil {
return err3
}
return nil return nil
} }
@ -1308,10 +1139,7 @@ func mergeRetrieveResults(dataArr []*segcorepb.RetrieveResults) (*segcorepb.Retr
return final, nil return final, nil
} }
func (q *queryCollection) publishRetrieveResult(msg msgstream.TsMsg, collectionID UniqueID) error { func (q *queryCollection) publishQueryResult(msg msgstream.TsMsg, collectionID UniqueID) error {
log.Debug("publishing retrieve result...",
zap.Int64("msgID", msg.ID()),
zap.Int64("collectionID", collectionID))
span, ctx := trace.StartSpanFromContext(msg.TraceCtx()) span, ctx := trace.StartSpanFromContext(msg.TraceCtx())
defer span.Finish() defer span.Finish()
msg.SetTraceCtx(ctx) msg.SetTraceCtx(ctx)
@ -1320,38 +1148,59 @@ func (q *queryCollection) publishRetrieveResult(msg msgstream.TsMsg, collectionI
err := q.queryResultMsgStream.Produce(&msgPack) err := q.queryResultMsgStream.Produce(&msgPack)
if err != nil { if err != nil {
log.Error(err.Error()) log.Error(err.Error())
} else {
log.Debug("publish retrieve result done",
zap.Int64("msgID", msg.ID()),
zap.Int64("collectionID", collectionID))
} }
return err return err
} }
func (q *queryCollection) publishFailedRetrieveResult(retrieveMsg *msgstream.RetrieveMsg, errMsg string) error { func (q *queryCollection) publishFailedQueryResult(msg msgstream.TsMsg, errMsg string) error {
span, ctx := trace.StartSpanFromContext(retrieveMsg.TraceCtx()) msgType := msg.Type()
span, ctx := trace.StartSpanFromContext(msg.TraceCtx())
defer span.Finish() defer span.Finish()
retrieveMsg.SetTraceCtx(ctx) msg.SetTraceCtx(ctx)
msgPack := msgstream.MsgPack{} msgPack := msgstream.MsgPack{}
resultChannelInt := 0 resultChannelInt := 0
baseMsg := msgstream.BaseMsg{
HashValues: []uint32{uint32(resultChannelInt)},
}
baseResult := &commonpb.MsgBase{
MsgID: msg.ID(),
Timestamp: msg.BeginTs(),
SourceID: msg.SourceID(),
}
switch msgType {
case commonpb.MsgType_Retrieve:
retrieveMsg := msg.(*msgstream.RetrieveMsg)
baseResult.MsgType = commonpb.MsgType_RetrieveResult
retrieveResultMsg := &msgstream.RetrieveResultMsg{ retrieveResultMsg := &msgstream.RetrieveResultMsg{
BaseMsg: msgstream.BaseMsg{HashValues: []uint32{uint32(resultChannelInt)}}, BaseMsg: baseMsg,
RetrieveResults: internalpb.RetrieveResults{ RetrieveResults: internalpb.RetrieveResults{
Base: &commonpb.MsgBase{ Base: baseResult,
MsgType: commonpb.MsgType_RetrieveResult,
MsgID: retrieveMsg.Base.MsgID,
Timestamp: retrieveMsg.Base.Timestamp,
SourceID: retrieveMsg.Base.SourceID,
},
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: errMsg}, Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: errMsg},
ResultChannelID: retrieveMsg.ResultChannelID, ResultChannelID: retrieveMsg.ResultChannelID,
Ids: nil, Ids: nil,
FieldsData: nil, FieldsData: nil,
}, },
} }
msgPack.Msgs = append(msgPack.Msgs, retrieveResultMsg) msgPack.Msgs = append(msgPack.Msgs, retrieveResultMsg)
case commonpb.MsgType_Search:
searchMsg := msg.(*msgstream.SearchMsg)
baseResult.MsgType = commonpb.MsgType_SearchResult
searchResultMsg := &msgstream.SearchResultMsg{
BaseMsg: baseMsg,
SearchResults: internalpb.SearchResults{
Base: baseResult,
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: errMsg},
ResultChannelID: searchMsg.ResultChannelID,
},
}
msgPack.Msgs = append(msgPack.Msgs, searchResultMsg)
default:
return fmt.Errorf("publish invalid msgType %d", msgType)
}
err := q.queryResultMsgStream.Produce(&msgPack) err := q.queryResultMsgStream.Produce(&msgPack)
if err != nil { if err != nil {
return err return err