diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index 31c39aa0d1..cb62cf0002 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -1924,6 +1924,7 @@ func (node *Proxy) CreateIndex(ctx context.Context, request *milvuspb.CreateInde req: request, rootCoord: node.rootCoord, indexCoord: node.indexCoord, + queryCoord: node.queryCoord, } method := "CreateIndex" diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 068b1482ee..f3ef2b3bf9 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -869,7 +869,7 @@ func (dpt *dropPartitionTask) PreExecute(ctx context.Context) error { return err } - collLoaded, err := isCollectionLoaded(ctx, dpt.queryCoord, []int64{collID}) + collLoaded, err := isCollectionLoaded(ctx, dpt.queryCoord, collID) if err != nil { return err } diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index a0f9f0e616..94c533eb41 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -55,6 +55,7 @@ type createIndexTask struct { ctx context.Context rootCoord types.RootCoord indexCoord types.IndexCoord + queryCoord types.QueryCoord result *commonpb.Status isAutoIndex bool @@ -283,7 +284,20 @@ func (cit *createIndexTask) PreExecute(ctx context.Context) error { } cit.fieldSchema = field // check index param, not accurate, only some static rules - return cit.parseIndexParams() + err = cit.parseIndexParams() + if err != nil { + return err + } + + loaded, err := isCollectionLoaded(ctx, cit.queryCoord, collID) + if err != nil { + return err + } + + if loaded { + return fmt.Errorf("create index failed, collection is loaded, please release it first") + } + return nil } func (cit *createIndexTask) Execute(ctx context.Context) error { @@ -504,7 +518,7 @@ func (dit *dropIndexTask) PreExecute(ctx context.Context) error { } dit.collectionID = collID - loaded, err := isCollectionLoaded(ctx, dit.queryCoord, []int64{collID}) + loaded, err := isCollectionLoaded(ctx, dit.queryCoord, collID) if err != nil { return err } diff --git a/internal/proxy/task_index_test.go b/internal/proxy/task_index_test.go index 33862c2b7f..ac2bf5aa20 100644 --- a/internal/proxy/task_index_test.go +++ b/internal/proxy/task_index_test.go @@ -21,14 +21,15 @@ import ( "errors" "testing" - "github.com/milvus-io/milvus/internal/util/typeutil" + "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/schemapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/util/funcutil" - "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus/internal/util/typeutil" ) func TestGetIndexStateTask_Execute(t *testing.T) { @@ -203,3 +204,87 @@ func TestDropIndexTask_PreExecute(t *testing.T) { assert.Error(t, err) }) } + +func TestCreateIndexTask_PreExecute(t *testing.T) { + collectionName := "collection1" + collectionID := UniqueID(1) + fieldName := newTestSchema().Fields[0].Name + + Params.Init() + ic := newMockIndexCoord() + ctx := context.Background() + + mockCache := newMockCache() + mockCache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { + return collectionID, nil + }) + mockCache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) { + return newTestSchema(), nil + }) + globalMetaCache = mockCache + + cit := createIndexTask{ + ctx: ctx, + req: &milvuspb.CreateIndexRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_CreateIndex, + }, + CollectionName: collectionName, + FieldName: fieldName, + }, + indexCoord: ic, + queryCoord: nil, + result: nil, + collectionID: collectionID, + } + + t.Run("normal", func(t *testing.T) { + showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { + return &querypb.ShowCollectionsResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, + CollectionIDs: []int64{}, + }, nil + } + qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock)) + qc.updateState(commonpb.StateCode_Healthy) + cit.queryCoord = qc + + err := cit.PreExecute(ctx) + assert.NoError(t, err) + }) + + t.Run("coll has been loaded", func(t *testing.T) { + showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { + return &querypb.ShowCollectionsResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, + CollectionIDs: []int64{collectionID}, + }, nil + } + qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock)) + qc.updateState(commonpb.StateCode_Healthy) + cit.queryCoord = qc + err := cit.PreExecute(ctx) + assert.Error(t, err) + }) + + t.Run("check load error", func(t *testing.T) { + showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { + return &querypb.ShowCollectionsResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "fail reason", + }, + CollectionIDs: nil, + }, errors.New("error") + } + qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock)) + qc.updateState(commonpb.StateCode_Healthy) + cit.queryCoord = qc + err := cit.PreExecute(ctx) + assert.Error(t, err) + }) +} diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index c25d598356..323d142722 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -2136,6 +2136,17 @@ func Test_createIndexTask_PreExecute(t *testing.T) { FieldName: fieldName, }, } + showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { + return &querypb.ShowCollectionsResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, + CollectionIDs: []int64{}, + }, nil + } + qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock)) + qc.updateState(commonpb.StateCode_Healthy) + cit.queryCoord = qc t.Run("normal", func(t *testing.T) { cache := newMockCache() diff --git a/internal/proxy/util.go b/internal/proxy/util.go index e3a13a4d18..fe9a01fea6 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -836,7 +836,7 @@ func validateIndexName(indexName string) error { return nil } -func isCollectionLoaded(ctx context.Context, qc types.QueryCoord, collIDs []int64) (bool, error) { +func isCollectionLoaded(ctx context.Context, qc types.QueryCoord, collID int64) (bool, error) { // get all loading collections resp, err := qc.ShowCollections(ctx, &querypb.ShowCollectionsRequest{ CollectionIDs: nil, @@ -848,23 +848,18 @@ func isCollectionLoaded(ctx context.Context, qc types.QueryCoord, collIDs []int6 return false, errors.New(resp.Status.Reason) } - loaded := false -LOOP: for _, loadedCollID := range resp.GetCollectionIDs() { - for _, collID := range collIDs { - if collID == loadedCollID { - loaded = true - break LOOP - } + if collID == loadedCollID { + return true, nil } } - return loaded, nil + return false, nil } -func isPartitionLoaded(ctx context.Context, qc types.QueryCoord, collIDs int64, partIDs []int64) (bool, error) { +func isPartitionLoaded(ctx context.Context, qc types.QueryCoord, collID int64, partIDs []int64) (bool, error) { // get all loading collections resp, err := qc.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{ - CollectionID: collIDs, + CollectionID: collID, PartitionIDs: nil, }) if err != nil { @@ -874,15 +869,12 @@ func isPartitionLoaded(ctx context.Context, qc types.QueryCoord, collIDs int64, return false, errors.New(resp.Status.Reason) } - loaded := false -LOOP: for _, loadedPartID := range resp.GetPartitionIDs() { for _, partID := range partIDs { if partID == loadedPartID { - loaded = true - break LOOP + return true, nil } } } - return loaded, nil + return false, nil } diff --git a/internal/proxy/util_test.go b/internal/proxy/util_test.go index 8c59f59156..b4aacb83e8 100644 --- a/internal/proxy/util_test.go +++ b/internal/proxy/util_test.go @@ -825,7 +825,7 @@ func Test_isCollectionIsLoaded(t *testing.T) { } qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock)) qc.updateState(commonpb.StateCode_Healthy) - loaded, err := isCollectionLoaded(ctx, qc, []int64{collID}) + loaded, err := isCollectionLoaded(ctx, qc, collID) assert.NoError(t, err) assert.True(t, loaded) }) @@ -843,7 +843,7 @@ func Test_isCollectionIsLoaded(t *testing.T) { } qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock)) qc.updateState(commonpb.StateCode_Healthy) - loaded, err := isCollectionLoaded(ctx, qc, []int64{collID}) + loaded, err := isCollectionLoaded(ctx, qc, collID) assert.Error(t, err) assert.False(t, loaded) }) @@ -861,7 +861,7 @@ func Test_isCollectionIsLoaded(t *testing.T) { } qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock)) qc.updateState(commonpb.StateCode_Healthy) - loaded, err := isCollectionLoaded(ctx, qc, []int64{collID}) + loaded, err := isCollectionLoaded(ctx, qc, collID) assert.Error(t, err) assert.False(t, loaded) }) diff --git a/tests/python_client/testcases/test_index.py b/tests/python_client/testcases/test_index.py index 22e1787807..6caf55f89d 100644 --- a/tests/python_client/testcases/test_index.py +++ b/tests/python_client/testcases/test_index.py @@ -1398,9 +1398,9 @@ class TestIndexString(TestcaseBase): data = cf.gen_default_list_data(ct.default_nb) collection_w.insert(data=data) collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index, index_name="vector_flat") - collection_w.load() index, _ = self.index_wrap.init_index(collection_w.collection, default_string_field_name, default_string_index_params) + collection_w.load() cf.assert_equal_index(index, collection_w.indexes[0]) assert collection_w.num_entities == default_nb diff --git a/tests/python_client/testcases/test_query.py b/tests/python_client/testcases/test_query.py index 65de4ce794..e19a63ecdb 100644 --- a/tests/python_client/testcases/test_query.py +++ b/tests/python_client/testcases/test_query.py @@ -748,7 +748,7 @@ class TestQueryParams(TestcaseBase): expected: verify query result """ # init collection with fields: int64, float, float_vec - collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] + collection_w, vectors = self.init_collection_general(prefix, insert_data=True, is_index=True)[0:2] df = vectors[0] # query with output_fields=["*", float_vector)