diff --git a/internal/querynode/query_collection.go b/internal/querynode/query_collection.go index ad170421c3..4be5cbbcd2 100644 --- a/internal/querynode/query_collection.go +++ b/internal/querynode/query_collection.go @@ -222,11 +222,17 @@ func (q *queryCollection) consumeQuery() { for _, msg := range msgPack.Msgs { switch sm := msg.(type) { case *msgstream.SearchMsg: - q.receiveQueryMsg(sm) + err := q.receiveQueryMsg(sm) + if err != nil { + log.Warn(err.Error()) + } case *msgstream.LoadBalanceSegmentsMsg: q.loadBalance(sm) case *msgstream.RetrieveMsg: - q.receiveQueryMsg(sm) + err := q.receiveQueryMsg(sm) + if err != nil { + log.Warn(err.Error()) + } default: log.Warn("unsupported msg type in search channel", zap.Any("msg", sm)) } @@ -266,7 +272,7 @@ func (q *queryCollection) loadBalance(msg *msgstream.LoadBalanceSegmentsMsg) { // zap.Int("num of segment", len(msg.Infos))) } -func (q *queryCollection) receiveQueryMsg(msg queryMsg) { +func (q *queryCollection) receiveQueryMsg(msg queryMsg) error { msgType := msg.Type() var collectionID UniqueID var msgTypeStr string @@ -288,8 +294,7 @@ func (q *queryCollection) receiveQueryMsg(msg queryMsg) { //) default: err := fmt.Errorf("receive invalid msgType = %d", msgType) - log.Warn(err.Error()) - return + return err } if collectionID != q.collectionID { //log.Warn("not target collection query request", @@ -297,7 +302,8 @@ func (q *queryCollection) receiveQueryMsg(msg queryMsg) { // zap.Int64("target collectionID", collectionID), // zap.Int64("msgID", msg.ID()), //) - return + err := fmt.Errorf("not target collection query request, collectionID = %d, targetCollectionID = %d, msgID = %d", q.collectionID, collectionID, msg.ID()) + return err } sp, ctx := trace.StartSpanFromContext(msg.TraceCtx()) @@ -307,38 +313,36 @@ func (q *queryCollection) receiveQueryMsg(msg queryMsg) { // check if collection has been released collection, err := q.historical.replica.getCollectionByID(collectionID) if err != nil { - log.Warn(err.Error()) - err = q.publishFailedQueryResult(msg, err.Error()) - if err != nil { - log.Warn(err.Error()) - } else { - log.Debug("do query failed in receiveQueryMsg, publish failed query result", - zap.Int64("collectionID", collectionID), - zap.Int64("msgID", msg.ID()), - zap.String("msgType", msgTypeStr), - ) + publishErr := q.publishFailedQueryResult(msg, err.Error()) + if publishErr != nil { + finalErr := fmt.Errorf("first err = %s, second err = %s", err, publishErr) + return finalErr } - return + log.Debug("do query failed in receiveQueryMsg, publish failed query result", + zap.Int64("collectionID", collectionID), + zap.Int64("msgID", msg.ID()), + zap.String("msgType", msgTypeStr), + ) + return err } guaranteeTs := msg.GuaranteeTs() if guaranteeTs >= collection.getReleaseTime() { err = fmt.Errorf("retrieve failed, collection has been released, msgID = %d, collectionID = %d", msg.ID(), collectionID) - log.Warn(err.Error()) - err = q.publishFailedQueryResult(msg, err.Error()) - if err != nil { - log.Warn(err.Error()) - } else { - log.Debug("do query failed in receiveQueryMsg, publish failed query result", - zap.Int64("collectionID", collectionID), - zap.Int64("msgID", msg.ID()), - zap.String("msgType", msgTypeStr), - ) + publishErr := q.publishFailedQueryResult(msg, err.Error()) + if publishErr != nil { + finalErr := fmt.Errorf("first err = %s, second err = %s", err, publishErr) + return finalErr } - return + log.Debug("do query failed in receiveQueryMsg, publish failed query result", + zap.Int64("collectionID", collectionID), + zap.Int64("msgID", msg.ID()), + zap.String("msgType", msgTypeStr), + ) + return err } serviceTime := q.getServiceableTime() - if guaranteeTs > serviceTime { + if guaranteeTs > serviceTime && len(collection.getVChannels()) > 0 { gt, _ := tsoutil.ParseTS(guaranteeTs) st, _ := tsoutil.ParseTS(serviceTime) log.Debug("query node::receiveQueryMsg: add to unsolvedMsg", @@ -357,7 +361,7 @@ func (q *queryCollection) receiveQueryMsg(msg queryMsg) { oplog.Float64("delta seconds", float64(guaranteeTs-serviceTime)/(1000.0*1000.0*1000.0)), ) sp.Finish() - return + return nil } tr.Record("get searchable time done") @@ -372,24 +376,23 @@ func (q *queryCollection) receiveQueryMsg(msg queryMsg) { case commonpb.MsgType_Search: err = q.search(msg) default: - err := fmt.Errorf("receive invalid msgType = %d", msgType) - log.Warn(err.Error()) - return + err = fmt.Errorf("receive invalid msgType = %d", msgType) + return err } tr.Record("operation done") if err != nil { - log.Warn(err.Error()) - err = q.publishFailedQueryResult(msg, err.Error()) - if err != nil { - log.Warn(err.Error()) - } else { - log.Debug("do query failed in receiveQueryMsg, publish failed query result", - zap.Int64("collectionID", collectionID), - zap.Int64("msgID", msg.ID()), - zap.String("msgType", msgTypeStr), - ) + publishErr := q.publishFailedQueryResult(msg, err.Error()) + if publishErr != nil { + finalErr := fmt.Errorf("first err = %s, second err = %s", err, publishErr) + return finalErr } + log.Debug("do query failed in receiveQueryMsg, publish failed query result", + zap.Int64("collectionID", collectionID), + zap.Int64("msgID", msg.ID()), + zap.String("msgType", msgTypeStr), + ) + return err } log.Debug("do query done in receiveQueryMsg", zap.Int64("collectionID", collectionID), @@ -398,6 +401,7 @@ func (q *queryCollection) receiveQueryMsg(msg queryMsg) { ) tr.Elapse("all done") sp.Finish() + return nil } func (q *queryCollection) doUnsolvedQueryMsg() { diff --git a/internal/querynode/query_collection_test.go b/internal/querynode/query_collection_test.go new file mode 100644 index 0000000000..c8cb455766 --- /dev/null +++ b/internal/querynode/query_collection_test.go @@ -0,0 +1,130 @@ +package querynode + +import ( + "context" + "encoding/binary" + "math" + "math/rand" + "testing" + + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/assert" + + etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" + "github.com/milvus-io/milvus/internal/msgstream" + "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/milvuspb" +) + +func TestQueryCollection_withoutVChannel(t *testing.T) { + m := map[string]interface{}{ + "PulsarAddress": Params.PulsarAddress, + "ReceiveBufSize": 1024, + "PulsarBufSize": 1024} + factory := msgstream.NewPmsFactory() + err := factory.SetParams(m) + assert.Nil(t, err) + etcdKV, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath) + assert.Nil(t, err) + + schema := genTestCollectionSchema(0, false, 2) + historical := newHistorical(context.Background(), nil, nil, factory, etcdKV) + + //add a segment to historical data + err = historical.replica.addCollection(0, schema) + assert.Nil(t, err) + err = historical.replica.addPartition(0, 1) + assert.Nil(t, err) + err = historical.replica.addSegment(2, 1, 0, "testChannel", segmentTypeSealed, true) + assert.Nil(t, err) + segment, err := historical.replica.getSegmentByID(2) + assert.Nil(t, err) + const N = 2 + rowID := []int32{1, 2} + timeStamp := []int64{0, 1} + age := []int64{10, 20} + vectorData := []float32{1, 2, 3, 4} + err = segment.segmentLoadFieldData(0, N, rowID) + assert.Nil(t, err) + err = segment.segmentLoadFieldData(1, N, timeStamp) + assert.Nil(t, err) + err = segment.segmentLoadFieldData(101, N, age) + assert.Nil(t, err) + err = segment.segmentLoadFieldData(100, N, vectorData) + assert.Nil(t, err) + + //create a streaming + streaming := newStreaming(context.Background(), factory, etcdKV) + err = streaming.replica.addCollection(0, schema) + assert.Nil(t, err) + err = streaming.replica.addPartition(0, 1) + assert.Nil(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + queryCollection := newQueryCollection(ctx, cancel, 0, historical, streaming, factory, nil, nil) + + producerChannels := []string{"testResultChannel"} + queryCollection.queryResultMsgStream.AsProducer(producerChannels) + + dim := 2 + // generate search rawData + var vec = make([]float32, dim) + for i := 0; i < dim; i++ { + vec[i] = rand.Float32() + } + dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }" + var searchRawData1 []byte + var searchRawData2 []byte + for i, ele := range vec { + buf := make([]byte, 4) + binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*2))) + searchRawData1 = append(searchRawData1, buf...) + } + for i, ele := range vec { + buf := make([]byte, 4) + binary.LittleEndian.PutUint32(buf, math.Float32bits(ele+float32(i*4))) + searchRawData2 = append(searchRawData2, buf...) + } + + // generate placeholder + placeholderValue := milvuspb.PlaceholderValue{ + Tag: "$0", + Type: milvuspb.PlaceholderType_FloatVector, + Values: [][]byte{searchRawData1, searchRawData2}, + } + placeholderGroup := milvuspb.PlaceholderGroup{ + Placeholders: []*milvuspb.PlaceholderValue{&placeholderValue}, + } + placeGroupByte, err := proto.Marshal(&placeholderGroup) + assert.Nil(t, err) + + queryMsg := &msgstream.SearchMsg{ + BaseMsg: msgstream.BaseMsg{ + Ctx: ctx, + BeginTimestamp: 10, + EndTimestamp: 10, + }, + SearchRequest: internalpb.SearchRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Search, + MsgID: 1, + Timestamp: Timestamp(10), + SourceID: 1, + }, + CollectionID: 0, + ResultChannelID: "testResultChannel", + Dsl: dslString, + PlaceholderGroup: placeGroupByte, + TravelTimestamp: 10, + GuaranteeTimestamp: 10, + }, + } + err = queryCollection.receiveQueryMsg(queryMsg) + assert.Nil(t, err) + + queryCollection.cancel() + queryCollection.close() + historical.close() + streaming.close() +} diff --git a/internal/querynode/query_node_test.go b/internal/querynode/query_node_test.go index a6b0b8593d..fd33d73c23 100644 --- a/internal/querynode/query_node_test.go +++ b/internal/querynode/query_node_test.go @@ -48,7 +48,7 @@ func setup() { Params.MetaRootPath = "/etcd/test/root/querynode" } -func genTestCollectionMeta(collectionID UniqueID, isBinary bool) *etcdpb.CollectionInfo { +func genTestCollectionSchema(collectionID UniqueID, isBinary bool, dim int) *schemapb.CollectionSchema { var fieldVec schemapb.FieldSchema if isBinary { fieldVec = schemapb.FieldSchema{ @@ -59,7 +59,7 @@ func genTestCollectionMeta(collectionID UniqueID, isBinary bool) *etcdpb.Collect TypeParams: []*commonpb.KeyValuePair{ { Key: "dim", - Value: "128", + Value: strconv.Itoa(dim * 8), }, }, IndexParams: []*commonpb.KeyValuePair{ @@ -78,7 +78,7 @@ func genTestCollectionMeta(collectionID UniqueID, isBinary bool) *etcdpb.Collect TypeParams: []*commonpb.KeyValuePair{ { Key: "dim", - Value: "16", + Value: strconv.Itoa(dim), }, }, IndexParams: []*commonpb.KeyValuePair{ @@ -97,16 +97,22 @@ func genTestCollectionMeta(collectionID UniqueID, isBinary bool) *etcdpb.Collect DataType: schemapb.DataType_Int32, } - schema := schemapb.CollectionSchema{ + schema := &schemapb.CollectionSchema{ AutoID: true, Fields: []*schemapb.FieldSchema{ &fieldVec, &fieldInt, }, } + return schema +} + +func genTestCollectionMeta(collectionID UniqueID, isBinary bool) *etcdpb.CollectionInfo { + schema := genTestCollectionSchema(collectionID, isBinary, 16) + collectionMeta := etcdpb.CollectionInfo{ ID: collectionID, - Schema: &schema, + Schema: schema, CreateTime: Timestamp(0), PartitionIDs: []UniqueID{defaultPartitionID}, }