Forbid createIndex if collection loaded before (#20100)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
This commit is contained in:
bigsheeper 2022-10-27 13:05:31 +08:00 committed by GitHub
parent 3ff0112e49
commit b074f530e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 129 additions and 26 deletions

View File

@ -1924,6 +1924,7 @@ func (node *Proxy) CreateIndex(ctx context.Context, request *milvuspb.CreateInde
req: request, req: request,
rootCoord: node.rootCoord, rootCoord: node.rootCoord,
indexCoord: node.indexCoord, indexCoord: node.indexCoord,
queryCoord: node.queryCoord,
} }
method := "CreateIndex" method := "CreateIndex"

View File

@ -869,7 +869,7 @@ func (dpt *dropPartitionTask) PreExecute(ctx context.Context) error {
return err return err
} }
collLoaded, err := isCollectionLoaded(ctx, dpt.queryCoord, []int64{collID}) collLoaded, err := isCollectionLoaded(ctx, dpt.queryCoord, collID)
if err != nil { if err != nil {
return err return err
} }

View File

@ -55,6 +55,7 @@ type createIndexTask struct {
ctx context.Context ctx context.Context
rootCoord types.RootCoord rootCoord types.RootCoord
indexCoord types.IndexCoord indexCoord types.IndexCoord
queryCoord types.QueryCoord
result *commonpb.Status result *commonpb.Status
isAutoIndex bool isAutoIndex bool
@ -283,7 +284,20 @@ func (cit *createIndexTask) PreExecute(ctx context.Context) error {
} }
cit.fieldSchema = field cit.fieldSchema = field
// check index param, not accurate, only some static rules // check index param, not accurate, only some static rules
return cit.parseIndexParams() err = cit.parseIndexParams()
if err != nil {
return err
}
loaded, err := isCollectionLoaded(ctx, cit.queryCoord, collID)
if err != nil {
return err
}
if loaded {
return fmt.Errorf("create index failed, collection is loaded, please release it first")
}
return nil
} }
func (cit *createIndexTask) Execute(ctx context.Context) error { func (cit *createIndexTask) Execute(ctx context.Context) error {
@ -504,7 +518,7 @@ func (dit *dropIndexTask) PreExecute(ctx context.Context) error {
} }
dit.collectionID = collID dit.collectionID = collID
loaded, err := isCollectionLoaded(ctx, dit.queryCoord, []int64{collID}) loaded, err := isCollectionLoaded(ctx, dit.queryCoord, collID)
if err != nil { if err != nil {
return err return err
} }

View File

@ -21,14 +21,15 @@ import (
"errors" "errors"
"testing" "testing"
"github.com/milvus-io/milvus/internal/util/typeutil" "github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb" "github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/funcutil" "github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/stretchr/testify/assert" "github.com/milvus-io/milvus/internal/util/typeutil"
) )
func TestGetIndexStateTask_Execute(t *testing.T) { func TestGetIndexStateTask_Execute(t *testing.T) {
@ -203,3 +204,87 @@ func TestDropIndexTask_PreExecute(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
}) })
} }
func TestCreateIndexTask_PreExecute(t *testing.T) {
collectionName := "collection1"
collectionID := UniqueID(1)
fieldName := newTestSchema().Fields[0].Name
Params.Init()
ic := newMockIndexCoord()
ctx := context.Background()
mockCache := newMockCache()
mockCache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
return collectionID, nil
})
mockCache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) {
return newTestSchema(), nil
})
globalMetaCache = mockCache
cit := createIndexTask{
ctx: ctx,
req: &milvuspb.CreateIndexRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_CreateIndex,
},
CollectionName: collectionName,
FieldName: fieldName,
},
indexCoord: ic,
queryCoord: nil,
result: nil,
collectionID: collectionID,
}
t.Run("normal", func(t *testing.T) {
showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionIDs: []int64{},
}, nil
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
qc.updateState(commonpb.StateCode_Healthy)
cit.queryCoord = qc
err := cit.PreExecute(ctx)
assert.NoError(t, err)
})
t.Run("coll has been loaded", func(t *testing.T) {
showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionIDs: []int64{collectionID},
}, nil
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
qc.updateState(commonpb.StateCode_Healthy)
cit.queryCoord = qc
err := cit.PreExecute(ctx)
assert.Error(t, err)
})
t.Run("check load error", func(t *testing.T) {
showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "fail reason",
},
CollectionIDs: nil,
}, errors.New("error")
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
qc.updateState(commonpb.StateCode_Healthy)
cit.queryCoord = qc
err := cit.PreExecute(ctx)
assert.Error(t, err)
})
}

View File

@ -2136,6 +2136,17 @@ func Test_createIndexTask_PreExecute(t *testing.T) {
FieldName: fieldName, FieldName: fieldName,
}, },
} }
showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
CollectionIDs: []int64{},
}, nil
}
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
qc.updateState(commonpb.StateCode_Healthy)
cit.queryCoord = qc
t.Run("normal", func(t *testing.T) { t.Run("normal", func(t *testing.T) {
cache := newMockCache() cache := newMockCache()

View File

@ -836,7 +836,7 @@ func validateIndexName(indexName string) error {
return nil return nil
} }
func isCollectionLoaded(ctx context.Context, qc types.QueryCoord, collIDs []int64) (bool, error) { func isCollectionLoaded(ctx context.Context, qc types.QueryCoord, collID int64) (bool, error) {
// get all loading collections // get all loading collections
resp, err := qc.ShowCollections(ctx, &querypb.ShowCollectionsRequest{ resp, err := qc.ShowCollections(ctx, &querypb.ShowCollectionsRequest{
CollectionIDs: nil, CollectionIDs: nil,
@ -848,23 +848,18 @@ func isCollectionLoaded(ctx context.Context, qc types.QueryCoord, collIDs []int6
return false, errors.New(resp.Status.Reason) return false, errors.New(resp.Status.Reason)
} }
loaded := false
LOOP:
for _, loadedCollID := range resp.GetCollectionIDs() { for _, loadedCollID := range resp.GetCollectionIDs() {
for _, collID := range collIDs { if collID == loadedCollID {
if collID == loadedCollID { return true, nil
loaded = true
break LOOP
}
} }
} }
return loaded, nil return false, nil
} }
func isPartitionLoaded(ctx context.Context, qc types.QueryCoord, collIDs int64, partIDs []int64) (bool, error) { func isPartitionLoaded(ctx context.Context, qc types.QueryCoord, collID int64, partIDs []int64) (bool, error) {
// get all loading collections // get all loading collections
resp, err := qc.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{ resp, err := qc.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{
CollectionID: collIDs, CollectionID: collID,
PartitionIDs: nil, PartitionIDs: nil,
}) })
if err != nil { if err != nil {
@ -874,15 +869,12 @@ func isPartitionLoaded(ctx context.Context, qc types.QueryCoord, collIDs int64,
return false, errors.New(resp.Status.Reason) return false, errors.New(resp.Status.Reason)
} }
loaded := false
LOOP:
for _, loadedPartID := range resp.GetPartitionIDs() { for _, loadedPartID := range resp.GetPartitionIDs() {
for _, partID := range partIDs { for _, partID := range partIDs {
if partID == loadedPartID { if partID == loadedPartID {
loaded = true return true, nil
break LOOP
} }
} }
} }
return loaded, nil return false, nil
} }

View File

@ -825,7 +825,7 @@ func Test_isCollectionIsLoaded(t *testing.T) {
} }
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock)) qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
qc.updateState(commonpb.StateCode_Healthy) qc.updateState(commonpb.StateCode_Healthy)
loaded, err := isCollectionLoaded(ctx, qc, []int64{collID}) loaded, err := isCollectionLoaded(ctx, qc, collID)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, loaded) assert.True(t, loaded)
}) })
@ -843,7 +843,7 @@ func Test_isCollectionIsLoaded(t *testing.T) {
} }
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock)) qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
qc.updateState(commonpb.StateCode_Healthy) qc.updateState(commonpb.StateCode_Healthy)
loaded, err := isCollectionLoaded(ctx, qc, []int64{collID}) loaded, err := isCollectionLoaded(ctx, qc, collID)
assert.Error(t, err) assert.Error(t, err)
assert.False(t, loaded) assert.False(t, loaded)
}) })
@ -861,7 +861,7 @@ func Test_isCollectionIsLoaded(t *testing.T) {
} }
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock)) qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
qc.updateState(commonpb.StateCode_Healthy) qc.updateState(commonpb.StateCode_Healthy)
loaded, err := isCollectionLoaded(ctx, qc, []int64{collID}) loaded, err := isCollectionLoaded(ctx, qc, collID)
assert.Error(t, err) assert.Error(t, err)
assert.False(t, loaded) assert.False(t, loaded)
}) })

View File

@ -1398,9 +1398,9 @@ class TestIndexString(TestcaseBase):
data = cf.gen_default_list_data(ct.default_nb) data = cf.gen_default_list_data(ct.default_nb)
collection_w.insert(data=data) collection_w.insert(data=data)
collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index, index_name="vector_flat") collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index, index_name="vector_flat")
collection_w.load()
index, _ = self.index_wrap.init_index(collection_w.collection, default_string_field_name, index, _ = self.index_wrap.init_index(collection_w.collection, default_string_field_name,
default_string_index_params) default_string_index_params)
collection_w.load()
cf.assert_equal_index(index, collection_w.indexes[0]) cf.assert_equal_index(index, collection_w.indexes[0])
assert collection_w.num_entities == default_nb assert collection_w.num_entities == default_nb

View File

@ -748,7 +748,7 @@ class TestQueryParams(TestcaseBase):
expected: verify query result expected: verify query result
""" """
# init collection with fields: int64, float, float_vec # init collection with fields: int64, float, float_vec
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2] collection_w, vectors = self.init_collection_general(prefix, insert_data=True, is_index=True)[0:2]
df = vectors[0] df = vectors[0]
# query with output_fields=["*", float_vector) # query with output_fields=["*", float_vector)