diff --git a/internal/proxy/collection_task_test.go b/internal/proxy/collection_task_test.go deleted file mode 100644 index 12da8e82fa..0000000000 --- a/internal/proxy/collection_task_test.go +++ /dev/null @@ -1,269 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. - -package proxy - -import ( - "context" - "strconv" - "testing" - "time" - - "github.com/milvus-io/milvus/internal/proto/schemapb" - - "github.com/milvus-io/milvus/internal/proto/commonpb" - - "github.com/milvus-io/milvus/internal/util/uniquegenerator" - - "github.com/stretchr/testify/assert" - - "github.com/golang/protobuf/proto" - - "github.com/milvus-io/milvus/internal/util/funcutil" - - "github.com/milvus-io/milvus/internal/proto/milvuspb" -) - -func TestCreateCollectionTask(t *testing.T) { - Params.Init() - - rc := NewRootCoordMock() - ctx := context.Background() - shardsNum := int32(2) - prefix := "TestCreateCollectionTask" - dbName := "" - collectionName := prefix + funcutil.GenRandomStr() - int64Field := "int64" - floatVecField := "fvec" - dim := 128 - - schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName) - var marshaledSchema []byte - marshaledSchema, err := proto.Marshal(schema) - assert.NoError(t, err) - - task := &createCollectionTask{ - Condition: NewTaskCondition(ctx), - CreateCollectionRequest: &milvuspb.CreateCollectionRequest{ - Base: nil, - DbName: dbName, - CollectionName: collectionName, - Schema: marshaledSchema, - ShardsNum: shardsNum, - }, - ctx: ctx, - rootCoord: rc, - result: nil, - schema: nil, - } - - t.Run("on enqueue", func(t *testing.T) { - err := task.OnEnqueue() - assert.NoError(t, err) - assert.Equal(t, commonpb.MsgType_CreateCollection, task.Type()) - }) - - t.Run("ctx", func(t *testing.T) { - traceCtx := task.TraceCtx() - assert.NotNil(t, traceCtx) - }) - - t.Run("id", func(t *testing.T) { - id := UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()) - task.SetID(id) - assert.Equal(t, id, task.ID()) - }) - - t.Run("name", func(t *testing.T) { - assert.Equal(t, CreateCollectionTaskName, task.Name()) - }) - - t.Run("ts", func(t *testing.T) { - ts := Timestamp(time.Now().UnixNano()) - task.SetTs(ts) - assert.Equal(t, ts, task.BeginTs()) - assert.Equal(t, ts, task.EndTs()) - }) - - t.Run("process task", func(t *testing.T) { - var err error - - err = task.PreExecute(ctx) - assert.NoError(t, err) - - err = task.Execute(ctx) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, task.result.ErrorCode) - - // recreate -> fail - err = task.Execute(ctx) - assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, task.result.ErrorCode) - - err = task.PostExecute(ctx) - assert.NoError(t, err) - }) - - t.Run("PreExecute", func(t *testing.T) { - var err error - - err = task.PreExecute(ctx) - assert.NoError(t, err) - - task.Schema = []byte{0x1, 0x2, 0x3, 0x4} - err = task.PreExecute(ctx) - assert.Error(t, err) - task.Schema = marshaledSchema - - task.ShardsNum = Params.MaxShardNum + 1 - err = task.PreExecute(ctx) - assert.Error(t, err) - task.ShardsNum = shardsNum - - reqBackup := proto.Clone(task.CreateCollectionRequest).(*milvuspb.CreateCollectionRequest) - schemaBackup := proto.Clone(schema).(*schemapb.CollectionSchema) - - schemaWithTooManyFields := &schemapb.CollectionSchema{ - Name: collectionName, - Description: "", - AutoID: false, - Fields: make([]*schemapb.FieldSchema, Params.MaxFieldNum+1), - } - marshaledSchemaWithTooManyFields, err := proto.Marshal(schemaWithTooManyFields) - assert.NoError(t, err) - task.CreateCollectionRequest.Schema = marshaledSchemaWithTooManyFields - err = task.PreExecute(ctx) - assert.Error(t, err) - - task.CreateCollectionRequest = reqBackup - - // ValidateCollectionName - - schema.Name = " " // empty - emptyNameSchema, err := proto.Marshal(schema) - assert.NoError(t, err) - task.CreateCollectionRequest.Schema = emptyNameSchema - err = task.PreExecute(ctx) - assert.Error(t, err) - - schema.Name = prefix - for i := 0; i < int(Params.MaxNameLength); i++ { - schema.Name += strconv.Itoa(i % 10) - } - tooLongNameSchema, err := proto.Marshal(schema) - assert.NoError(t, err) - task.CreateCollectionRequest.Schema = tooLongNameSchema - err = task.PreExecute(ctx) - assert.Error(t, err) - - schema.Name = "$" // invalid first char - invalidFirstCharSchema, err := proto.Marshal(schema) - assert.NoError(t, err) - task.CreateCollectionRequest.Schema = invalidFirstCharSchema - err = task.PreExecute(ctx) - assert.Error(t, err) - - // ValidateDuplicatedFieldName - schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema) - schema.Fields = append(schema.Fields, schema.Fields[0]) - duplicatedFieldsSchema, err := proto.Marshal(schema) - assert.NoError(t, err) - task.CreateCollectionRequest.Schema = duplicatedFieldsSchema - err = task.PreExecute(ctx) - assert.Error(t, err) - - // ValidatePrimaryKey - schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema) - for idx := range schema.Fields { - schema.Fields[idx].IsPrimaryKey = false - } - noPrimaryFieldsSchema, err := proto.Marshal(schema) - assert.NoError(t, err) - task.CreateCollectionRequest.Schema = noPrimaryFieldsSchema - err = task.PreExecute(ctx) - assert.Error(t, err) - - // ValidateFieldName - schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema) - for idx := range schema.Fields { - schema.Fields[idx].Name = "$" - } - invalidFieldNameSchema, err := proto.Marshal(schema) - assert.NoError(t, err) - task.CreateCollectionRequest.Schema = invalidFieldNameSchema - err = task.PreExecute(ctx) - assert.Error(t, err) - - // ValidateVectorField - schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema) - for idx := range schema.Fields { - if schema.Fields[idx].DataType == schemapb.DataType_FloatVector || - schema.Fields[idx].DataType == schemapb.DataType_BinaryVector { - schema.Fields[idx].TypeParams = nil - } - } - noDimSchema, err := proto.Marshal(schema) - assert.NoError(t, err) - task.CreateCollectionRequest.Schema = noDimSchema - err = task.PreExecute(ctx) - assert.Error(t, err) - - schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema) - for idx := range schema.Fields { - if schema.Fields[idx].DataType == schemapb.DataType_FloatVector || - schema.Fields[idx].DataType == schemapb.DataType_BinaryVector { - schema.Fields[idx].TypeParams = []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "not int", - }, - } - } - } - dimNotIntSchema, err := proto.Marshal(schema) - assert.NoError(t, err) - task.CreateCollectionRequest.Schema = dimNotIntSchema - err = task.PreExecute(ctx) - assert.Error(t, err) - - schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema) - for idx := range schema.Fields { - if schema.Fields[idx].DataType == schemapb.DataType_FloatVector || - schema.Fields[idx].DataType == schemapb.DataType_BinaryVector { - schema.Fields[idx].TypeParams = []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: strconv.Itoa(int(Params.MaxDimension) + 1), - }, - } - } - } - tooLargeDimSchema, err := proto.Marshal(schema) - assert.NoError(t, err) - task.CreateCollectionRequest.Schema = tooLargeDimSchema - err = task.PreExecute(ctx) - assert.Error(t, err) - - schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema) - schema.Fields[1].DataType = schemapb.DataType_BinaryVector - schema.Fields[1].TypeParams = []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: strconv.Itoa(int(Params.MaxDimension) + 1), - }, - } - binaryTooLargeDimSchema, err := proto.Marshal(schema) - assert.NoError(t, err) - task.CreateCollectionRequest.Schema = binaryTooLargeDimSchema - err = task.PreExecute(ctx) - assert.Error(t, err) - }) -} diff --git a/internal/proxy/rootcoord_mock_test.go b/internal/proxy/rootcoord_mock_test.go index dd3806221c..48ca8b427b 100644 --- a/internal/proxy/rootcoord_mock_test.go +++ b/internal/proxy/rootcoord_mock_test.go @@ -137,6 +137,15 @@ func (coord *RootCoordMock) GetComponentStates(ctx context.Context) (*internalpb } func (coord *RootCoordMock) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { + code := coord.state.Load().(internalpb.StateCode) + if code != internalpb.StateCode_Healthy { + return &milvuspb.StringResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("state code = %s", internalpb.StateCode_name[int32(code)]), + }, + }, nil + } return &milvuspb.StringResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, @@ -151,6 +160,15 @@ func (coord *RootCoordMock) Register() error { } func (coord *RootCoordMock) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { + code := coord.state.Load().(internalpb.StateCode) + if code != internalpb.StateCode_Healthy { + return &milvuspb.StringResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("state code = %s", internalpb.StateCode_name[int32(code)]), + }, + }, nil + } return &milvuspb.StringResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, @@ -161,6 +179,13 @@ func (coord *RootCoordMock) GetTimeTickChannel(ctx context.Context) (*milvuspb.S } func (coord *RootCoordMock) CreateCollection(ctx context.Context, req *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) { + code := coord.state.Load().(internalpb.StateCode) + if code != internalpb.StateCode_Healthy { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("state code = %s", internalpb.StateCode_name[int32(code)]), + }, nil + } coord.collMtx.Lock() defer coord.collMtx.Unlock() @@ -228,6 +253,13 @@ func (coord *RootCoordMock) CreateCollection(ctx context.Context, req *milvuspb. } func (coord *RootCoordMock) DropCollection(ctx context.Context, req *milvuspb.DropCollectionRequest) (*commonpb.Status, error) { + code := coord.state.Load().(internalpb.StateCode) + if code != internalpb.StateCode_Healthy { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("state code = %s", internalpb.StateCode_name[int32(code)]), + }, nil + } coord.collMtx.Lock() defer coord.collMtx.Unlock() @@ -255,6 +287,16 @@ func (coord *RootCoordMock) DropCollection(ctx context.Context, req *milvuspb.Dr } func (coord *RootCoordMock) HasCollection(ctx context.Context, req *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) { + code := coord.state.Load().(internalpb.StateCode) + if code != internalpb.StateCode_Healthy { + return &milvuspb.BoolResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("state code = %s", internalpb.StateCode_name[int32(code)]), + }, + Value: false, + }, nil + } coord.collMtx.RLock() defer coord.collMtx.RUnlock() @@ -270,6 +312,17 @@ func (coord *RootCoordMock) HasCollection(ctx context.Context, req *milvuspb.Has } func (coord *RootCoordMock) DescribeCollection(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { + code := coord.state.Load().(internalpb.StateCode) + if code != internalpb.StateCode_Healthy { + return &milvuspb.DescribeCollectionResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("state code = %s", internalpb.StateCode_name[int32(code)]), + }, + Schema: nil, + CollectionID: 0, + }, nil + } coord.collMtx.RLock() defer coord.collMtx.RUnlock() @@ -300,6 +353,16 @@ func (coord *RootCoordMock) DescribeCollection(ctx context.Context, req *milvusp } func (coord *RootCoordMock) ShowCollections(ctx context.Context, req *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error) { + code := coord.state.Load().(internalpb.StateCode) + if code != internalpb.StateCode_Healthy { + return &milvuspb.ShowCollectionsResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("state code = %s", internalpb.StateCode_name[int32(code)]), + }, + CollectionNames: nil, + }, nil + } coord.collMtx.RLock() defer coord.collMtx.RUnlock() @@ -333,6 +396,13 @@ func (coord *RootCoordMock) ShowCollections(ctx context.Context, req *milvuspb.S } func (coord *RootCoordMock) CreatePartition(ctx context.Context, req *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) { + code := coord.state.Load().(internalpb.StateCode) + if code != internalpb.StateCode_Healthy { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("state code = %s", internalpb.StateCode_name[int32(code)]), + }, nil + } coord.collMtx.RLock() defer coord.collMtx.RUnlock() @@ -372,6 +442,13 @@ func (coord *RootCoordMock) CreatePartition(ctx context.Context, req *milvuspb.C } func (coord *RootCoordMock) DropPartition(ctx context.Context, req *milvuspb.DropPartitionRequest) (*commonpb.Status, error) { + code := coord.state.Load().(internalpb.StateCode) + if code != internalpb.StateCode_Healthy { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("state code = %s", internalpb.StateCode_name[int32(code)]), + }, nil + } coord.collMtx.RLock() defer coord.collMtx.RUnlock() @@ -404,6 +481,16 @@ func (coord *RootCoordMock) DropPartition(ctx context.Context, req *milvuspb.Dro } func (coord *RootCoordMock) HasPartition(ctx context.Context, req *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) { + code := coord.state.Load().(internalpb.StateCode) + if code != internalpb.StateCode_Healthy { + return &milvuspb.BoolResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("state code = %s", internalpb.StateCode_name[int32(code)]), + }, + Value: false, + }, nil + } coord.collMtx.RLock() defer coord.collMtx.RUnlock() @@ -432,6 +519,17 @@ func (coord *RootCoordMock) HasPartition(ctx context.Context, req *milvuspb.HasP } func (coord *RootCoordMock) ShowPartitions(ctx context.Context, req *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { + code := coord.state.Load().(internalpb.StateCode) + if code != internalpb.StateCode_Healthy { + return &milvuspb.ShowPartitionsResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("rootcoord is not healthy, state code = %s", internalpb.StateCode_name[int32(code)]), + }, + PartitionNames: nil, + PartitionIDs: nil, + }, nil + } coord.collMtx.RLock() defer coord.collMtx.RUnlock() @@ -477,6 +575,13 @@ func (coord *RootCoordMock) ShowPartitions(ctx context.Context, req *milvuspb.Sh } func (coord *RootCoordMock) CreateIndex(ctx context.Context, req *milvuspb.CreateIndexRequest) (*commonpb.Status, error) { + code := coord.state.Load().(internalpb.StateCode) + if code != internalpb.StateCode_Healthy { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("state code = %s", internalpb.StateCode_name[int32(code)]), + }, nil + } return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, Reason: "", @@ -484,6 +589,16 @@ func (coord *RootCoordMock) CreateIndex(ctx context.Context, req *milvuspb.Creat } func (coord *RootCoordMock) DescribeIndex(ctx context.Context, req *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) { + code := coord.state.Load().(internalpb.StateCode) + if code != internalpb.StateCode_Healthy { + return &milvuspb.DescribeIndexResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("state code = %s", internalpb.StateCode_name[int32(code)]), + }, + IndexDescriptions: nil, + }, nil + } return &milvuspb.DescribeIndexResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, @@ -494,6 +609,13 @@ func (coord *RootCoordMock) DescribeIndex(ctx context.Context, req *milvuspb.Des } func (coord *RootCoordMock) DropIndex(ctx context.Context, req *milvuspb.DropIndexRequest) (*commonpb.Status, error) { + code := coord.state.Load().(internalpb.StateCode) + if code != internalpb.StateCode_Healthy { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("state code = %s", internalpb.StateCode_name[int32(code)]), + }, nil + } return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, Reason: "", @@ -501,6 +623,17 @@ func (coord *RootCoordMock) DropIndex(ctx context.Context, req *milvuspb.DropInd } func (coord *RootCoordMock) AllocTimestamp(ctx context.Context, req *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error) { + code := coord.state.Load().(internalpb.StateCode) + if code != internalpb.StateCode_Healthy { + return &rootcoordpb.AllocTimestampResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("state code = %s", internalpb.StateCode_name[int32(code)]), + }, + Timestamp: 0, + Count: 0, + }, nil + } coord.lastTsMtx.Lock() defer coord.lastTsMtx.Unlock() @@ -521,6 +654,17 @@ func (coord *RootCoordMock) AllocTimestamp(ctx context.Context, req *rootcoordpb } func (coord *RootCoordMock) AllocID(ctx context.Context, req *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) { + code := coord.state.Load().(internalpb.StateCode) + if code != internalpb.StateCode_Healthy { + return &rootcoordpb.AllocIDResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("state code = %s", internalpb.StateCode_name[int32(code)]), + }, + ID: 0, + Count: 0, + }, nil + } begin, _ := uniquegenerator.GetUniqueIntGeneratorIns().GetInts(int(req.Count)) return &rootcoordpb.AllocIDResponse{ Status: &commonpb.Status{ @@ -533,6 +677,13 @@ func (coord *RootCoordMock) AllocID(ctx context.Context, req *rootcoordpb.AllocI } func (coord *RootCoordMock) UpdateChannelTimeTick(ctx context.Context, req *internalpb.ChannelTimeTickMsg) (*commonpb.Status, error) { + code := coord.state.Load().(internalpb.StateCode) + if code != internalpb.StateCode_Healthy { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("state code = %s", internalpb.StateCode_name[int32(code)]), + }, nil + } return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, Reason: "", @@ -540,6 +691,16 @@ func (coord *RootCoordMock) UpdateChannelTimeTick(ctx context.Context, req *inte } func (coord *RootCoordMock) DescribeSegment(ctx context.Context, req *milvuspb.DescribeSegmentRequest) (*milvuspb.DescribeSegmentResponse, error) { + code := coord.state.Load().(internalpb.StateCode) + if code != internalpb.StateCode_Healthy { + return &milvuspb.DescribeSegmentResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("state code = %s", internalpb.StateCode_name[int32(code)]), + }, + IndexID: 0, + }, nil + } return &milvuspb.DescribeSegmentResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, @@ -552,6 +713,16 @@ func (coord *RootCoordMock) DescribeSegment(ctx context.Context, req *milvuspb.D } func (coord *RootCoordMock) ShowSegments(ctx context.Context, req *milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error) { + code := coord.state.Load().(internalpb.StateCode) + if code != internalpb.StateCode_Healthy { + return &milvuspb.ShowSegmentsResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("state code = %s", internalpb.StateCode_name[int32(code)]), + }, + SegmentIDs: nil, + }, nil + } return &milvuspb.ShowSegmentsResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, @@ -562,6 +733,13 @@ func (coord *RootCoordMock) ShowSegments(ctx context.Context, req *milvuspb.Show } func (coord *RootCoordMock) ReleaseDQLMessageStream(ctx context.Context, in *proxypb.ReleaseDQLMessageStreamRequest) (*commonpb.Status, error) { + code := coord.state.Load().(internalpb.StateCode) + if code != internalpb.StateCode_Healthy { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("state code = %s", internalpb.StateCode_name[int32(code)]), + }, nil + } return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, Reason: "", @@ -569,6 +747,13 @@ func (coord *RootCoordMock) ReleaseDQLMessageStream(ctx context.Context, in *pro } func (coord *RootCoordMock) SegmentFlushCompleted(ctx context.Context, in *datapb.SegmentFlushCompletedMsg) (*commonpb.Status, error) { + code := coord.state.Load().(internalpb.StateCode) + if code != internalpb.StateCode_Healthy { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("state code = %s", internalpb.StateCode_name[int32(code)]), + }, nil + } return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, Reason: "", @@ -576,6 +761,17 @@ func (coord *RootCoordMock) SegmentFlushCompleted(ctx context.Context, in *datap } func (coord *RootCoordMock) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { + code := coord.state.Load().(internalpb.StateCode) + if code != internalpb.StateCode_Healthy { + + return &milvuspb.GetMetricsResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "failed", + }, + Response: "", + }, nil + } rootCoordTopology := metricsinfo.RootCoordTopology{ Self: metricsinfo.RootCoordInfos{ BaseComponentInfos: metricsinfo.BaseComponentInfos{ diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index 2812c5e57d..f3d1fc0acd 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -2,9 +2,9 @@ package proxy import ( "context" - "fmt" "strconv" "testing" + "time" "github.com/golang/protobuf/proto" "github.com/milvus-io/milvus/internal/log" @@ -12,6 +12,8 @@ import ( "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/proto/schemapb" + "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/uniquegenerator" "github.com/stretchr/testify/assert" ) @@ -59,28 +61,6 @@ func constructCollectionSchema( } } -func constructCreateCollectionRequest( - schema *schemapb.CollectionSchema, - dbName, collectionName string, - shardsNum int32, -) *milvuspb.CreateCollectionRequest { - bs, err := proto.Marshal(schema) - if err != nil { - panic( - fmt.Sprintf( - "failed to marshal collection schema, schema: %v, error: %v", - schema, - err)) - } - return &milvuspb.CreateCollectionRequest{ - Base: nil, - DbName: dbName, - CollectionName: collectionName, - Schema: bs, - ShardsNum: shardsNum, - } -} - func TestGetNumRowsOfScalarField(t *testing.T) { cases := []struct { datas interface{} @@ -682,3 +662,649 @@ func TestSearchTask(t *testing.T) { // TODO, add decode result, reduce result test } + +func TestCreateCollectionTask(t *testing.T) { + Params.Init() + + rc := NewRootCoordMock() + rc.Start() + ctx := context.Background() + shardsNum := int32(2) + prefix := "TestCreateCollectionTask" + dbName := "" + collectionName := prefix + funcutil.GenRandomStr() + int64Field := "int64" + floatVecField := "fvec" + dim := 128 + + schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName) + var marshaledSchema []byte + marshaledSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + + task := &createCollectionTask{ + Condition: NewTaskCondition(ctx), + CreateCollectionRequest: &milvuspb.CreateCollectionRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: shardsNum, + }, + ctx: ctx, + rootCoord: rc, + result: nil, + schema: nil, + } + + t.Run("on enqueue", func(t *testing.T) { + err := task.OnEnqueue() + assert.NoError(t, err) + assert.Equal(t, commonpb.MsgType_CreateCollection, task.Type()) + }) + + t.Run("ctx", func(t *testing.T) { + traceCtx := task.TraceCtx() + assert.NotNil(t, traceCtx) + }) + + t.Run("id", func(t *testing.T) { + id := UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()) + task.SetID(id) + assert.Equal(t, id, task.ID()) + }) + + t.Run("name", func(t *testing.T) { + assert.Equal(t, CreateCollectionTaskName, task.Name()) + }) + + t.Run("ts", func(t *testing.T) { + ts := Timestamp(time.Now().UnixNano()) + task.SetTs(ts) + assert.Equal(t, ts, task.BeginTs()) + assert.Equal(t, ts, task.EndTs()) + }) + + t.Run("process task", func(t *testing.T) { + var err error + + err = task.PreExecute(ctx) + assert.NoError(t, err) + + err = task.Execute(ctx) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, task.result.ErrorCode) + + // recreate -> fail + err = task.Execute(ctx) + assert.NoError(t, err) + assert.NotEqual(t, commonpb.ErrorCode_Success, task.result.ErrorCode) + + err = task.PostExecute(ctx) + assert.NoError(t, err) + }) + + t.Run("PreExecute", func(t *testing.T) { + var err error + + err = task.PreExecute(ctx) + assert.NoError(t, err) + + task.Schema = []byte{0x1, 0x2, 0x3, 0x4} + err = task.PreExecute(ctx) + assert.Error(t, err) + task.Schema = marshaledSchema + + task.ShardsNum = Params.MaxShardNum + 1 + err = task.PreExecute(ctx) + assert.Error(t, err) + task.ShardsNum = shardsNum + + reqBackup := proto.Clone(task.CreateCollectionRequest).(*milvuspb.CreateCollectionRequest) + schemaBackup := proto.Clone(schema).(*schemapb.CollectionSchema) + + schemaWithTooManyFields := &schemapb.CollectionSchema{ + Name: collectionName, + Description: "", + AutoID: false, + Fields: make([]*schemapb.FieldSchema, Params.MaxFieldNum+1), + } + marshaledSchemaWithTooManyFields, err := proto.Marshal(schemaWithTooManyFields) + assert.NoError(t, err) + task.CreateCollectionRequest.Schema = marshaledSchemaWithTooManyFields + err = task.PreExecute(ctx) + assert.Error(t, err) + + task.CreateCollectionRequest = reqBackup + + // ValidateCollectionName + + schema.Name = " " // empty + emptyNameSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + task.CreateCollectionRequest.Schema = emptyNameSchema + err = task.PreExecute(ctx) + assert.Error(t, err) + + schema.Name = prefix + for i := 0; i < int(Params.MaxNameLength); i++ { + schema.Name += strconv.Itoa(i % 10) + } + tooLongNameSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + task.CreateCollectionRequest.Schema = tooLongNameSchema + err = task.PreExecute(ctx) + assert.Error(t, err) + + schema.Name = "$" // invalid first char + invalidFirstCharSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + task.CreateCollectionRequest.Schema = invalidFirstCharSchema + err = task.PreExecute(ctx) + assert.Error(t, err) + + // ValidateDuplicatedFieldName + schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema) + schema.Fields = append(schema.Fields, schema.Fields[0]) + duplicatedFieldsSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + task.CreateCollectionRequest.Schema = duplicatedFieldsSchema + err = task.PreExecute(ctx) + assert.Error(t, err) + + // ValidatePrimaryKey + schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema) + for idx := range schema.Fields { + schema.Fields[idx].IsPrimaryKey = false + } + noPrimaryFieldsSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + task.CreateCollectionRequest.Schema = noPrimaryFieldsSchema + err = task.PreExecute(ctx) + assert.Error(t, err) + + // ValidateFieldName + schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema) + for idx := range schema.Fields { + schema.Fields[idx].Name = "$" + } + invalidFieldNameSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + task.CreateCollectionRequest.Schema = invalidFieldNameSchema + err = task.PreExecute(ctx) + assert.Error(t, err) + + // ValidateVectorField + schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema) + for idx := range schema.Fields { + if schema.Fields[idx].DataType == schemapb.DataType_FloatVector || + schema.Fields[idx].DataType == schemapb.DataType_BinaryVector { + schema.Fields[idx].TypeParams = nil + } + } + noDimSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + task.CreateCollectionRequest.Schema = noDimSchema + err = task.PreExecute(ctx) + assert.Error(t, err) + + schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema) + for idx := range schema.Fields { + if schema.Fields[idx].DataType == schemapb.DataType_FloatVector || + schema.Fields[idx].DataType == schemapb.DataType_BinaryVector { + schema.Fields[idx].TypeParams = []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "not int", + }, + } + } + } + dimNotIntSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + task.CreateCollectionRequest.Schema = dimNotIntSchema + err = task.PreExecute(ctx) + assert.Error(t, err) + + schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema) + for idx := range schema.Fields { + if schema.Fields[idx].DataType == schemapb.DataType_FloatVector || + schema.Fields[idx].DataType == schemapb.DataType_BinaryVector { + schema.Fields[idx].TypeParams = []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: strconv.Itoa(int(Params.MaxDimension) + 1), + }, + } + } + } + tooLargeDimSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + task.CreateCollectionRequest.Schema = tooLargeDimSchema + err = task.PreExecute(ctx) + assert.Error(t, err) + + schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema) + schema.Fields[1].DataType = schemapb.DataType_BinaryVector + schema.Fields[1].TypeParams = []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: strconv.Itoa(int(Params.MaxDimension) + 1), + }, + } + binaryTooLargeDimSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + task.CreateCollectionRequest.Schema = binaryTooLargeDimSchema + err = task.PreExecute(ctx) + assert.Error(t, err) + }) +} + +func TestDropCollectionTask(t *testing.T) { + Params.Init() + rc := NewRootCoordMock() + rc.Start() + ctx := context.Background() + InitMetaCache(rc) + + master := newMockGetChannelsService() + query := newMockGetChannelsService() + factory := newSimpleMockMsgStreamFactory() + channelMgr := newChannelsMgrImpl(master.GetChannels, nil, query.GetChannels, nil, factory) + defer channelMgr.removeAllDMLStream() + + prefix := "TestDropCollectionTask" + dbName := "" + collectionName := prefix + funcutil.GenRandomStr() + + shardsNum := int32(2) + int64Field := "int64" + floatVecField := "fvec" + dim := 128 + + schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName) + marshaledSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + + createColReq := &milvuspb.CreateCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_DropCollection, + MsgID: 100, + Timestamp: 100, + }, + DbName: dbName, + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: shardsNum, + } + + //CreateCollection + task := &dropCollectionTask{ + Condition: NewTaskCondition(ctx), + DropCollectionRequest: &milvuspb.DropCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_DropCollection, + MsgID: 100, + Timestamp: 100, + }, + DbName: dbName, + CollectionName: collectionName, + }, + ctx: ctx, + chMgr: channelMgr, + rootCoord: rc, + result: nil, + } + task.PreExecute(ctx) + + assert.Equal(t, commonpb.MsgType_DropCollection, task.Type()) + assert.Equal(t, UniqueID(100), task.ID()) + assert.Equal(t, Timestamp(100), task.BeginTs()) + assert.Equal(t, Timestamp(100), task.EndTs()) + assert.Equal(t, Params.ProxyID, task.GetBase().GetSourceID()) + // missing collectionID in globalMetaCache + err = task.Execute(ctx) + assert.NotNil(t, err) + // createCollection in RootCood and fill GlobalMetaCache + rc.CreateCollection(ctx, createColReq) + globalMetaCache.GetCollectionID(ctx, collectionName) + + // success to drop collection + err = task.Execute(ctx) + assert.Nil(t, err) + + // illegal name + task.CollectionName = "#0xc0de" + err = task.PreExecute(ctx) + assert.NotNil(t, err) + + task.CollectionName = collectionName + err = task.PreExecute(ctx) + assert.Nil(t, err) + +} + +func TestHasCollectionTask(t *testing.T) { + Params.Init() + rc := NewRootCoordMock() + rc.Start() + ctx := context.Background() + InitMetaCache(rc) + prefix := "TestHasCollectionTask" + dbName := "" + collectionName := prefix + funcutil.GenRandomStr() + + shardsNum := int32(2) + int64Field := "int64" + floatVecField := "fvec" + dim := 128 + + schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName) + marshaledSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + + createColReq := &milvuspb.CreateCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_DropCollection, + MsgID: 100, + Timestamp: 100, + }, + DbName: dbName, + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: shardsNum, + } + + //CreateCollection + task := &hasCollectionTask{ + Condition: NewTaskCondition(ctx), + HasCollectionRequest: &milvuspb.HasCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_HasCollection, + MsgID: 100, + Timestamp: 100, + }, + DbName: dbName, + CollectionName: collectionName, + }, + ctx: ctx, + rootCoord: rc, + result: nil, + } + task.PreExecute(ctx) + + assert.Equal(t, commonpb.MsgType_HasCollection, task.Type()) + assert.Equal(t, UniqueID(100), task.ID()) + assert.Equal(t, Timestamp(100), task.BeginTs()) + assert.Equal(t, Timestamp(100), task.EndTs()) + assert.Equal(t, Params.ProxyID, task.GetBase().GetSourceID()) + // missing collectionID in globalMetaCache + err = task.Execute(ctx) + assert.Nil(t, err) + assert.Equal(t, false, task.result.Value) + // createCollection in RootCood and fill GlobalMetaCache + rc.CreateCollection(ctx, createColReq) + globalMetaCache.GetCollectionID(ctx, collectionName) + + // success to drop collection + err = task.Execute(ctx) + assert.Nil(t, err) + assert.Equal(t, true, task.result.Value) + + // illegal name + task.CollectionName = "#0xc0de" + err = task.PreExecute(ctx) + assert.NotNil(t, err) + + rc.updateState(internalpb.StateCode_Abnormal) + task.CollectionName = collectionName + err = task.PreExecute(ctx) + assert.Nil(t, err) + err = task.Execute(ctx) + assert.NotNil(t, err) + +} + +func TestDescribeCollectionTask(t *testing.T) { + Params.Init() + rc := NewRootCoordMock() + rc.Start() + ctx := context.Background() + InitMetaCache(rc) + prefix := "TestDescribeCollectionTask" + dbName := "" + collectionName := prefix + funcutil.GenRandomStr() + + //CreateCollection + task := &describeCollectionTask{ + Condition: NewTaskCondition(ctx), + DescribeCollectionRequest: &milvuspb.DescribeCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_DescribeCollection, + MsgID: 100, + Timestamp: 100, + }, + DbName: dbName, + CollectionName: collectionName, + }, + ctx: ctx, + rootCoord: rc, + result: nil, + } + task.PreExecute(ctx) + + assert.Equal(t, commonpb.MsgType_DescribeCollection, task.Type()) + assert.Equal(t, UniqueID(100), task.ID()) + assert.Equal(t, Timestamp(100), task.BeginTs()) + assert.Equal(t, Timestamp(100), task.EndTs()) + assert.Equal(t, Params.ProxyID, task.GetBase().GetSourceID()) + // missing collectionID in globalMetaCache + err := task.Execute(ctx) + assert.Nil(t, err) + + // illegal name + task.CollectionName = "#0xc0de" + err = task.PreExecute(ctx) + assert.NotNil(t, err) + + rc.Stop() + task.CollectionName = collectionName + err = task.PreExecute(ctx) + assert.Nil(t, err) + err = task.Execute(ctx) + assert.Nil(t, err) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, task.result.Status.ErrorCode) +} + +func TestCreatePartitionTask(t *testing.T) { + Params.Init() + rc := NewRootCoordMock() + rc.Start() + ctx := context.Background() + prefix := "TestCreatePartitionTask" + dbName := "" + collectionName := prefix + funcutil.GenRandomStr() + partitionName := prefix + funcutil.GenRandomStr() + + task := &createPartitionTask{ + Condition: NewTaskCondition(ctx), + CreatePartitionRequest: &milvuspb.CreatePartitionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_CreatePartition, + MsgID: 100, + Timestamp: 100, + }, + DbName: dbName, + CollectionName: collectionName, + PartitionName: partitionName, + }, + ctx: ctx, + rootCoord: rc, + result: nil, + } + task.PreExecute(ctx) + + assert.Equal(t, commonpb.MsgType_CreatePartition, task.Type()) + assert.Equal(t, UniqueID(100), task.ID()) + assert.Equal(t, Timestamp(100), task.BeginTs()) + assert.Equal(t, Timestamp(100), task.EndTs()) + assert.Equal(t, Params.ProxyID, task.GetBase().GetSourceID()) + err := task.Execute(ctx) + assert.NotNil(t, err) + + task.CollectionName = "#0xc0de" + err = task.PreExecute(ctx) + assert.NotNil(t, err) + + task.CollectionName = collectionName + task.PartitionName = "#0xc0de" + err = task.PreExecute(ctx) + assert.NotNil(t, err) +} + +func TestDropPartitionTask(t *testing.T) { + Params.Init() + rc := NewRootCoordMock() + rc.Start() + ctx := context.Background() + prefix := "TestDropPartitionTask" + dbName := "" + collectionName := prefix + funcutil.GenRandomStr() + partitionName := prefix + funcutil.GenRandomStr() + + task := &dropPartitionTask{ + Condition: NewTaskCondition(ctx), + DropPartitionRequest: &milvuspb.DropPartitionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_DropPartition, + MsgID: 100, + Timestamp: 100, + }, + DbName: dbName, + CollectionName: collectionName, + PartitionName: partitionName, + }, + ctx: ctx, + rootCoord: rc, + result: nil, + } + task.PreExecute(ctx) + + assert.Equal(t, commonpb.MsgType_DropPartition, task.Type()) + assert.Equal(t, UniqueID(100), task.ID()) + assert.Equal(t, Timestamp(100), task.BeginTs()) + assert.Equal(t, Timestamp(100), task.EndTs()) + assert.Equal(t, Params.ProxyID, task.GetBase().GetSourceID()) + err := task.Execute(ctx) + assert.NotNil(t, err) + + task.CollectionName = "#0xc0de" + err = task.PreExecute(ctx) + assert.NotNil(t, err) + + task.CollectionName = collectionName + task.PartitionName = "#0xc0de" + err = task.PreExecute(ctx) + assert.NotNil(t, err) +} + +func TestHasPartitionTask(t *testing.T) { + Params.Init() + rc := NewRootCoordMock() + rc.Start() + ctx := context.Background() + prefix := "TestHasPartitionTask" + dbName := "" + collectionName := prefix + funcutil.GenRandomStr() + partitionName := prefix + funcutil.GenRandomStr() + + task := &hasPartitionTask{ + Condition: NewTaskCondition(ctx), + HasPartitionRequest: &milvuspb.HasPartitionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_HasPartition, + MsgID: 100, + Timestamp: 100, + }, + DbName: dbName, + CollectionName: collectionName, + PartitionName: partitionName, + }, + ctx: ctx, + rootCoord: rc, + result: nil, + } + task.PreExecute(ctx) + + assert.Equal(t, commonpb.MsgType_HasPartition, task.Type()) + assert.Equal(t, UniqueID(100), task.ID()) + assert.Equal(t, Timestamp(100), task.BeginTs()) + assert.Equal(t, Timestamp(100), task.EndTs()) + assert.Equal(t, Params.ProxyID, task.GetBase().GetSourceID()) + err := task.Execute(ctx) + assert.NotNil(t, err) + + task.CollectionName = "#0xc0de" + err = task.PreExecute(ctx) + assert.NotNil(t, err) + + task.CollectionName = collectionName + task.PartitionName = "#0xc0de" + err = task.PreExecute(ctx) + assert.NotNil(t, err) +} + +func TestShowPartitionsTask(t *testing.T) { + Params.Init() + rc := NewRootCoordMock() + rc.Start() + ctx := context.Background() + prefix := "TestShowPartitionsTask" + dbName := "" + collectionName := prefix + funcutil.GenRandomStr() + partitionName := prefix + funcutil.GenRandomStr() + + task := &showPartitionsTask{ + Condition: NewTaskCondition(ctx), + ShowPartitionsRequest: &milvuspb.ShowPartitionsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_ShowPartitions, + MsgID: 100, + Timestamp: 100, + }, + DbName: dbName, + CollectionName: collectionName, + PartitionNames: []string{partitionName}, + Type: milvuspb.ShowType_All, + }, + ctx: ctx, + rootCoord: rc, + result: nil, + } + task.PreExecute(ctx) + + assert.Equal(t, commonpb.MsgType_ShowPartitions, task.Type()) + assert.Equal(t, UniqueID(100), task.ID()) + assert.Equal(t, Timestamp(100), task.BeginTs()) + assert.Equal(t, Timestamp(100), task.EndTs()) + assert.Equal(t, Params.ProxyID, task.GetBase().GetSourceID()) + err := task.Execute(ctx) + assert.NotNil(t, err) + + task.CollectionName = "#0xc0de" + err = task.PreExecute(ctx) + assert.NotNil(t, err) + + task.CollectionName = collectionName + task.ShowPartitionsRequest.Type = milvuspb.ShowType_InMemory + task.PartitionNames = []string{"#0xc0de"} + err = task.PreExecute(ctx) + assert.NotNil(t, err) + + task.CollectionName = collectionName + task.PartitionNames = []string{partitionName} + task.ShowPartitionsRequest.Type = milvuspb.ShowType_InMemory + err = task.Execute(ctx) + assert.NotNil(t, err) + +}