diff --git a/internal/proxy/task.go b/internal/proxy/task.go index d80d61bafb..02e72232b8 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -2099,6 +2099,14 @@ func (qt *queryTask) getChannels() ([]pChan, error) { return nil, err } + _, err = qt.chMgr.getChannels(collID) + if err != nil { + err := qt.chMgr.createDMLMsgStream(collID) + if err != nil { + return nil, err + } + } + return qt.chMgr.getChannels(collID) } @@ -2119,6 +2127,7 @@ func (qt *queryTask) getVChannels() ([]vChan, error) { return qt.chMgr.getVChannels(collID) } +/* not used func parseIdsFromExpr(exprStr string, schema *typeutil.SchemaHelper) ([]int64, error) { expr, err := parseQueryExpr(schema, exprStr) if err != nil { @@ -2146,6 +2155,7 @@ func parseIdsFromExpr(exprStr string, schema *typeutil.SchemaHelper) ([]int64, e return nil, errors.New("not top level term") } } +*/ func IDs2Expr(fieldName string, ids []int64) string { idsStr := strings.Trim(strings.Join(strings.Fields(fmt.Sprint(ids)), ", "), "[]") @@ -2157,14 +2167,6 @@ func (qt *queryTask) PreExecute(ctx context.Context) error { qt.Base.SourceID = Params.ProxyID collectionName := qt.query.CollectionName - collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName) - if err != nil { - log.Debug("Failed to get collection id.", zap.Any("collectionName", collectionName), - zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query")) - return err - } - log.Info("Get collection id by name.", zap.Any("collectionName", collectionName), - zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query")) if err := ValidateCollectionName(qt.query.CollectionName); err != nil { log.Debug("Invalid collection name.", zap.Any("collectionName", collectionName), @@ -2174,6 +2176,15 @@ func (qt *queryTask) PreExecute(ctx context.Context) error { log.Info("Validate collection name.", zap.Any("collectionName", collectionName), zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query")) + collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName) + if err != nil { + log.Debug("Failed to get collection id.", zap.Any("collectionName", collectionName), + zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query")) + return err + } + log.Info("Get collection id by name.", zap.Any("collectionName", collectionName), + zap.Any("requestID", qt.Base.MsgID), zap.Any("requestType", "query")) + for _, tag := range qt.query.PartitionNames { if err := ValidatePartitionTag(tag, false); err != nil { log.Debug("Invalid partition name.", zap.Any("partitionName", tag), @@ -2215,10 +2226,7 @@ func (qt *queryTask) PreExecute(ctx context.Context) error { return fmt.Errorf("collection %v was not loaded into memory", collectionName) } - schema, err := globalMetaCache.GetCollectionSchema(ctx, qt.query.CollectionName) - if err != nil { // err is not nil if collection not exists - return err - } + schema, _ := globalMetaCache.GetCollectionSchema(ctx, qt.query.CollectionName) // schemaHelper, err := typeutil.CreateSchemaHelper(schema) // if err != nil { // return err diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index ce275959f8..476a38841a 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -2408,3 +2408,305 @@ func TestSearchTask_Execute(t *testing.T) { assert.Error(t, task.Execute(ctx)) // TODO(dragondriver): cover getDQLStream } + +func TestQueryTask_all(t *testing.T) { + var err error + + Params.Init() + Params.RetrieveResultChannelNames = []string{funcutil.GenRandomStr()} + + rc := NewRootCoordMock() + rc.Start() + defer rc.Stop() + + ctx := context.Background() + + err = InitMetaCache(rc) + assert.NoError(t, err) + + shardsNum := int32(2) + prefix := "TestQueryTask_all" + dbName := "" + collectionName := prefix + funcutil.GenRandomStr() + boolField := "bool" + int32Field := "int32" + int64Field := "int64" + floatField := "float" + doubleField := "double" + floatVecField := "fvec" + binaryVecField := "bvec" + fieldsLen := len([]string{boolField, int32Field, int64Field, floatField, doubleField, floatVecField, binaryVecField}) + dim := 128 + expr := fmt.Sprintf("%s > 0", int64Field) + hitNum := 10 + + schema := constructCollectionSchemaWithAllType( + boolField, int32Field, int64Field, floatField, doubleField, + floatVecField, binaryVecField, dim, collectionName) + marshaledSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + + createColT := &createCollectionTask{ + Condition: NewTaskCondition(ctx), + CreateCollectionRequest: &milvuspb.CreateCollectionRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: shardsNum, + }, + ctx: ctx, + rootCoord: rc, + result: nil, + schema: nil, + } + + assert.NoError(t, createColT.OnEnqueue()) + assert.NoError(t, createColT.PreExecute(ctx)) + assert.NoError(t, createColT.Execute(ctx)) + assert.NoError(t, createColT.PostExecute(ctx)) + + dmlChannelsFunc := getDmlChannelsFunc(ctx, rc) + query := newMockGetChannelsService() + factory := newSimpleMockMsgStreamFactory() + chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory) + defer chMgr.removeAllDMLStream() + defer chMgr.removeAllDQLStream() + + collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName) + assert.NoError(t, err) + + qc := NewQueryCoordMock() + qc.Start() + defer qc.Stop() + status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_LoadCollection, + MsgID: 0, + Timestamp: 0, + SourceID: Params.ProxyID, + }, + DbID: 0, + CollectionID: collectionID, + Schema: nil, + }) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode) + + task := &queryTask{ + Condition: NewTaskCondition(ctx), + RetrieveRequest: &internalpb.RetrieveRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Retrieve, + MsgID: 0, + Timestamp: 0, + SourceID: Params.ProxyID, + }, + ResultChannelID: strconv.Itoa(int(Params.ProxyID)), + DbID: 0, + CollectionID: collectionID, + PartitionIDs: nil, + SerializedExprPlan: nil, + OutputFieldsId: make([]int64, fieldsLen), + TravelTimestamp: 0, + GuaranteeTimestamp: 0, + }, + ctx: ctx, + resultBuf: make(chan []*internalpb.RetrieveResults), + result: &milvuspb.QueryResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, + FieldsData: nil, + }, + query: &milvuspb.QueryRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Retrieve, + MsgID: 0, + Timestamp: 0, + SourceID: Params.ProxyID, + }, + DbName: dbName, + CollectionName: collectionName, + Expr: expr, + OutputFields: nil, + PartitionNames: nil, + TravelTimestamp: 0, + GuaranteeTimestamp: 0, + }, + chMgr: chMgr, + qc: qc, + ids: nil, + } + for i := 0; i < fieldsLen; i++ { + task.RetrieveRequest.OutputFieldsId[i] = int64(common.StartOfUserFieldID + i) + } + + // simple mock for query node + // TODO(dragondriver): should we replace this mock using RocksMq or MemMsgStream? + + err = chMgr.createDQLStream(collectionID) + assert.NoError(t, err) + stream, err := chMgr.getDQLStream(collectionID) + assert.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + consumeCtx, cancel := context.WithCancel(ctx) + go func() { + defer wg.Done() + for { + select { + case <-consumeCtx.Done(): + return + case pack := <-stream.Chan(): + for _, msg := range pack.Msgs { + _, ok := msg.(*msgstream.RetrieveMsg) + assert.True(t, ok) + // TODO(dragondriver): construct result according to the request + + result1 := &internalpb.RetrieveResults{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_RetrieveResult, + MsgID: 0, + Timestamp: 0, + SourceID: 0, + }, + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + Reason: "", + }, + ResultChannelID: strconv.Itoa(int(Params.ProxyID)), + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: generateInt64Array(hitNum), + }, + }, + }, + FieldsData: make([]*schemapb.FieldData, fieldsLen), + SealedSegmentIDsRetrieved: nil, + ChannelIDsRetrieved: nil, + GlobalSealedSegmentIDs: nil, + } + + result1.FieldsData[0] = &schemapb.FieldData{ + Type: schemapb.DataType_Bool, + FieldName: boolField, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: generateBoolArray(hitNum), + }, + }, + }, + }, + FieldId: common.StartOfUserFieldID + 0, + } + + result1.FieldsData[1] = &schemapb.FieldData{ + Type: schemapb.DataType_Int32, + FieldName: int32Field, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: generateInt32Array(hitNum), + }, + }, + }, + }, + FieldId: common.StartOfUserFieldID + 1, + } + + result1.FieldsData[2] = &schemapb.FieldData{ + Type: schemapb.DataType_Int64, + FieldName: int64Field, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: generateInt64Array(hitNum), + }, + }, + }, + }, + FieldId: common.StartOfUserFieldID + 2, + } + + result1.FieldsData[3] = &schemapb.FieldData{ + Type: schemapb.DataType_Float, + FieldName: floatField, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: generateFloat32Array(hitNum), + }, + }, + }, + }, + FieldId: common.StartOfUserFieldID + 3, + } + + result1.FieldsData[4] = &schemapb.FieldData{ + Type: schemapb.DataType_Double, + FieldName: doubleField, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_DoubleData{ + DoubleData: &schemapb.DoubleArray{ + Data: generateFloat64Array(hitNum), + }, + }, + }, + }, + FieldId: common.StartOfUserFieldID + 4, + } + + result1.FieldsData[5] = &schemapb.FieldData{ + Type: schemapb.DataType_FloatVector, + FieldName: doubleField, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: int64(dim), + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: generateFloatVectors(hitNum, dim), + }, + }, + }, + }, + FieldId: common.StartOfUserFieldID + 5, + } + + result1.FieldsData[6] = &schemapb.FieldData{ + Type: schemapb.DataType_BinaryVector, + FieldName: doubleField, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: int64(dim), + Data: &schemapb.VectorField_BinaryVector{ + BinaryVector: generateBinaryVectors(hitNum, dim), + }, + }, + }, + FieldId: common.StartOfUserFieldID + 6, + } + + // send search result + task.resultBuf <- []*internalpb.RetrieveResults{result1} + } + } + } + }() + + assert.NoError(t, task.OnEnqueue()) + assert.NoError(t, task.PreExecute(ctx)) + assert.NoError(t, task.Execute(ctx)) + assert.NoError(t, task.PostExecute(ctx)) + + cancel() + wg.Wait() +}