diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go
index e590f9c5b1..bee7e618a0 100644
--- a/internal/proxy/impl.go
+++ b/internal/proxy/impl.go
@@ -1248,7 +1248,6 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
 	it := &insertTask{
 		ctx:       ctx,
 		Condition: NewTaskCondition(ctx),
-		dataCoord: node.dataCoord,
 		req:       request,
 		BaseInsertTask: BaseInsertTask{
 			BaseMsg: msgstream.BaseMsg{
diff --git a/internal/proxy/mock_test.go b/internal/proxy/mock_test.go
index 346252d134..ac8e33ce21 100644
--- a/internal/proxy/mock_test.go
+++ b/internal/proxy/mock_test.go
@@ -277,7 +277,23 @@ func (ms *simpleMockMsgStream) AsConsumer(channels []string, subName string) {
 }
 
 func (ms *simpleMockMsgStream) ComputeProduceChannelIndexes(tsMsgs []msgstream.TsMsg) [][]int32 {
-	return nil
+	if len(tsMsgs) <= 0 {
+		return nil
+	}
+	reBucketValues := make([][]int32, len(tsMsgs))
+	channelNum := uint32(1)
+	if channelNum == 0 {
+		return nil
+	}
+	for idx, tsMsg := range tsMsgs {
+		hashValues := tsMsg.HashKeys()
+		bucketValues := make([]int32, len(hashValues))
+		for index, hashValue := range hashValues {
+			bucketValues[index] = int32(hashValue % channelNum)
+		}
+		reBucketValues[idx] = bucketValues
+	}
+	return reBucketValues
 }
 
 func (ms *simpleMockMsgStream) SetRepackFunc(repackFunc msgstream.RepackFunc) {
diff --git a/internal/proxy/rootcoord_mock_test.go b/internal/proxy/rootcoord_mock_test.go
index eccf76d27f..ef2373f3f2 100644
--- a/internal/proxy/rootcoord_mock_test.go
+++ b/internal/proxy/rootcoord_mock_test.go
@@ -353,8 +353,14 @@ func (coord *RootCoordMock) DescribeCollection(ctx context.Context, req *milvusp
 	coord.collMtx.RLock()
 	defer coord.collMtx.RUnlock()
 
+	var collID UniqueID
+	usingID := false
+	if req.CollectionName == "" {
+		usingID = true
+	}
+
 	collID, exist := coord.collName2ID[req.CollectionName]
-	if !exist {
+	if !exist && !usingID {
 		return &milvuspb.DescribeCollectionResponse{
 			Status: &commonpb.Status{
 				ErrorCode: commonpb.ErrorCode_CollectionNotExists,
@@ -363,6 +369,10 @@
 		}, nil
 	}
 
+	if usingID {
+		collID = req.CollectionID
+	}
+
 	meta := coord.collID2Meta[collID]
 	if meta.shardsNum == 0 {
 		meta.shardsNum = int32(len(meta.virtualChannelNames))
diff --git a/internal/proxy/task.go b/internal/proxy/task.go
index 02e72232b8..9ab4b5c24f 100644
--- a/internal/proxy/task.go
+++ b/internal/proxy/task.go
@@ -113,7 +113,6 @@ type insertTask struct {
 	ctx context.Context
 
 	result         *milvuspb.MutationResult
-	dataCoord      types.DataCoord
 	rowIDAllocator *allocator.IDAllocator
 	segIDAssigner  *SegIDAssigner
 	chMgr          channelsMgr
diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go
index 476a38841a..60251cd766 100644
--- a/internal/proxy/task_test.go
+++ b/internal/proxy/task_test.go
@@ -13,6 +13,8 @@ import (
 	"testing"
 	"time"
 
+	"github.com/milvus-io/milvus/internal/allocator"
+
 	"github.com/golang/protobuf/proto"
 	"github.com/milvus-io/milvus/internal/common"
 	"github.com/milvus-io/milvus/internal/log"
@@ -2710,3 +2712,270 @@ func TestQueryTask_all(t *testing.T) {
 	cancel()
 	wg.Wait()
 }
+
+func TestInsertTask_all(t *testing.T) {
+	var err error
+
+	Params.Init()
+	Params.RetrieveResultChannelNames = []string{funcutil.GenRandomStr()}
+
+	rc := NewRootCoordMock()
+	rc.Start()
+	defer rc.Stop()
+
+	ctx := context.Background()
+
+	err = InitMetaCache(rc)
+	assert.NoError(t, err)
+
+	shardsNum := int32(2)
+	prefix := "TestInsertTask_all"
+	dbName := ""
+	collectionName := prefix + funcutil.GenRandomStr()
+	partitionName := prefix + funcutil.GenRandomStr()
+	boolField := "bool"
+	int32Field := "int32"
+	int64Field := "int64"
+	floatField := "float"
+	doubleField := "double"
+	floatVecField := "fvec"
+	binaryVecField := "bvec"
+	fieldsLen := len([]string{boolField, int32Field, int64Field, floatField, doubleField, floatVecField, binaryVecField})
+	dim := 128
+	nb := 10
+
+	schema := constructCollectionSchemaWithAllType(
+		boolField, int32Field, int64Field, floatField, doubleField,
+		floatVecField, binaryVecField, dim, collectionName)
+	marshaledSchema, err := proto.Marshal(schema)
+	assert.NoError(t, err)
+
+	createColT := &createCollectionTask{
+		Condition: NewTaskCondition(ctx),
+		CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
+			Base:           nil,
+			DbName:         dbName,
+			CollectionName: collectionName,
+			Schema:         marshaledSchema,
+			ShardsNum:      shardsNum,
+		},
+		ctx:       ctx,
+		rootCoord: rc,
+		result:    nil,
+		schema:    nil,
+	}
+
+	assert.NoError(t, createColT.OnEnqueue())
+	assert.NoError(t, createColT.PreExecute(ctx))
+	assert.NoError(t, createColT.Execute(ctx))
+	assert.NoError(t, createColT.PostExecute(ctx))
+
+	_, _ = rc.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{
+		Base: &commonpb.MsgBase{
+			MsgType:   commonpb.MsgType_CreatePartition,
+			MsgID:     0,
+			Timestamp: 0,
+			SourceID:  Params.ProxyID,
+		},
+		DbName:         dbName,
+		CollectionName: collectionName,
+		PartitionName:  partitionName,
+	})
+
+	collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
+	assert.NoError(t, err)
+
+	dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
+	query := newMockGetChannelsService()
+	factory := newSimpleMockMsgStreamFactory()
+	chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory)
+	defer chMgr.removeAllDMLStream()
+	defer chMgr.removeAllDQLStream()
+
+	err = chMgr.createDMLMsgStream(collectionID)
+	assert.NoError(t, err)
+	pchans, err := chMgr.getChannels(collectionID)
+	assert.NoError(t, err)
+
+	interval := time.Millisecond * 10
+	tso := newMockTsoAllocator()
+
+	ticker := newChannelsTimeTicker(ctx, interval, []string{}, newGetStatisticsFunc(pchans), tso)
+	_ = ticker.start()
+	defer ticker.close()
+
+	idAllocator, err := allocator.NewIDAllocator(ctx, rc, Params.ProxyID)
+	assert.NoError(t, err)
+	_ = idAllocator.Start()
+	defer idAllocator.Close()
+
+	segAllocator, err := NewSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
+	assert.NoError(t, err)
+	segAllocator.Init()
+	_ = segAllocator.Start()
+	defer segAllocator.Close()
+
+	hash := generateHashKeys(nb)
+	task := &insertTask{
+		BaseInsertTask: BaseInsertTask{
+			BaseMsg: msgstream.BaseMsg{
+				HashValues: hash,
+			},
+			InsertRequest: internalpb.InsertRequest{
+				Base: &commonpb.MsgBase{
+					MsgType: commonpb.MsgType_Insert,
+					MsgID:   0,
+				},
+				CollectionName: collectionName,
+				PartitionName:  partitionName,
+			},
+		},
+		req: &milvuspb.InsertRequest{
+			Base: &commonpb.MsgBase{
+				MsgType:   commonpb.MsgType_Insert,
+				MsgID:     0,
+				Timestamp: 0,
+				SourceID:  Params.ProxyID,
+			},
+			DbName:         dbName,
+			CollectionName: collectionName,
+			PartitionName:  partitionName,
+			FieldsData:     make([]*schemapb.FieldData, fieldsLen),
+			HashKeys:       hash,
+			NumRows:        uint32(nb),
+		},
+		Condition: NewTaskCondition(ctx),
+		ctx:       ctx,
+		result: &milvuspb.MutationResult{
+			Status: &commonpb.Status{
+				ErrorCode: commonpb.ErrorCode_Success,
+				Reason:    "",
+			},
+			IDs:          nil,
+			SuccIndex:    nil,
+			ErrIndex:     nil,
+			Acknowledged: false,
+			InsertCnt:    0,
+			DeleteCnt:    0,
+			UpsertCnt:    0,
+			Timestamp:    0,
+		},
+		rowIDAllocator: idAllocator,
+		segIDAssigner:  segAllocator,
+		chMgr:          chMgr,
+		chTicker:       ticker,
+		vChannels:      nil,
+		pChannels:      nil,
+		schema:         nil,
+	}
+
+	task.req.FieldsData[0] = &schemapb.FieldData{
+		Type:      schemapb.DataType_Bool,
+		FieldName: boolField,
+		Field: &schemapb.FieldData_Scalars{
+			Scalars: &schemapb.ScalarField{
+				Data: &schemapb.ScalarField_BoolData{
+					BoolData: &schemapb.BoolArray{
+						Data: generateBoolArray(nb),
+					},
+				},
+			},
+		},
+		FieldId: common.StartOfUserFieldID + 0,
+	}
+
+	task.req.FieldsData[1] = &schemapb.FieldData{
+		Type:      schemapb.DataType_Int32,
+		FieldName: int32Field,
+		Field: &schemapb.FieldData_Scalars{
+			Scalars: &schemapb.ScalarField{
+				Data: &schemapb.ScalarField_IntData{
+					IntData: &schemapb.IntArray{
+						Data: generateInt32Array(nb),
+					},
+				},
+			},
+		},
+		FieldId: common.StartOfUserFieldID + 1,
+	}
+
+	task.req.FieldsData[2] = &schemapb.FieldData{
+		Type:      schemapb.DataType_Int64,
+		FieldName: int64Field,
+		Field: &schemapb.FieldData_Scalars{
+			Scalars: &schemapb.ScalarField{
+				Data: &schemapb.ScalarField_LongData{
+					LongData: &schemapb.LongArray{
+						Data: generateInt64Array(nb),
+					},
+				},
+			},
+		},
+		FieldId: common.StartOfUserFieldID + 2,
+	}
+
+	task.req.FieldsData[3] = &schemapb.FieldData{
+		Type:      schemapb.DataType_Float,
+		FieldName: floatField,
+		Field: &schemapb.FieldData_Scalars{
+			Scalars: &schemapb.ScalarField{
+				Data: &schemapb.ScalarField_FloatData{
+					FloatData: &schemapb.FloatArray{
+						Data: generateFloat32Array(nb),
+					},
+				},
+			},
+		},
+		FieldId: common.StartOfUserFieldID + 3,
+	}
+
+	task.req.FieldsData[4] = &schemapb.FieldData{
+		Type:      schemapb.DataType_Double,
+		FieldName: doubleField,
+		Field: &schemapb.FieldData_Scalars{
+			Scalars: &schemapb.ScalarField{
+				Data: &schemapb.ScalarField_DoubleData{
+					DoubleData: &schemapb.DoubleArray{
+						Data: generateFloat64Array(nb),
+					},
+				},
+			},
+		},
+		FieldId: common.StartOfUserFieldID + 4,
+	}
+
+	task.req.FieldsData[5] = &schemapb.FieldData{
+		Type:      schemapb.DataType_FloatVector,
+		FieldName: floatVecField,
+		Field: &schemapb.FieldData_Vectors{
+			Vectors: &schemapb.VectorField{
+				Dim: int64(dim),
+				Data: &schemapb.VectorField_FloatVector{
+					FloatVector: &schemapb.FloatArray{
+						Data: generateFloatVectors(nb, dim),
+					},
+				},
+			},
+		},
+		FieldId: common.StartOfUserFieldID + 5,
+	}
+
+	task.req.FieldsData[6] = &schemapb.FieldData{
+		Type:      schemapb.DataType_BinaryVector,
+		FieldName: binaryVecField,
+		Field: &schemapb.FieldData_Vectors{
+			Vectors: &schemapb.VectorField{
+				Dim: int64(dim),
+				Data: &schemapb.VectorField_BinaryVector{
+					BinaryVector: generateBinaryVectors(nb, dim),
+				},
+			},
+		},
+		FieldId: common.StartOfUserFieldID + 6,
+	}
+
+	assert.NoError(t, task.OnEnqueue())
+	assert.NoError(t, task.PreExecute(ctx))
+	assert.NoError(t, task.Execute(ctx))
+	assert.NoError(t, task.PostExecute(ctx))
+}
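
Note (not part of the patch): the RootCoordMock change above makes DescribeCollection fall back to req.CollectionID whenever req.CollectionName is empty. Below is a minimal sketch of a test helper that exercises that path; the helper name describeCollectionByID is an illustrative assumption, only RootCoordMock, UniqueID, and the milvuspb request fields come from the diff, and the milvuspb import path is assumed to match this tree.

package proxy

import (
	"context"

	"github.com/milvus-io/milvus/internal/proto/milvuspb"
)

// describeCollectionByID is a hypothetical helper: it asks the mock for
// collection metadata using only the numeric ID, relying on the usingID
// fallback added to RootCoordMock.DescribeCollection in this patch.
func describeCollectionByID(ctx context.Context, coord *RootCoordMock, collID UniqueID) (*milvuspb.DescribeCollectionResponse, error) {
	return coord.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{
		CollectionName: "", // empty name -> the mock resolves by CollectionID instead
		CollectionID:   collID,
	})
}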
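
Note (not part of the patch): the mock stream's ComputeProduceChannelIndexes applies the same hash-modulo bucketing as the production repack, with channelNum pinned to 1 so every message lands in bucket 0. A standalone sketch of that rule, decoupled from msgstream.TsMsg (function and parameter names here are illustrative):

// bucketByHash mirrors the mock's bucketing: each message contributes one
// bucket index per hash key, computed as hashKey % channelNum.
// With channelNum == 1 (as in simpleMockMsgStream) every index is 0.
func bucketByHash(hashKeysPerMsg [][]uint32, channelNum uint32) [][]int32 {
	if len(hashKeysPerMsg) == 0 || channelNum == 0 {
		return nil
	}
	buckets := make([][]int32, len(hashKeysPerMsg))
	for i, keys := range hashKeysPerMsg {
		row := make([]int32, len(keys))
		for j, k := range keys {
			row[j] = int32(k % channelNum)
		}
		buckets[i] = row
	}
	return buckets
}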