mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 17:48:29 +08:00
Forbid createIndex if collection loaded before (#20100)
Signed-off-by: bigsheeper <yihao.dai@zilliz.com> Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
This commit is contained in:
parent
3ff0112e49
commit
b074f530e6
@ -1924,6 +1924,7 @@ func (node *Proxy) CreateIndex(ctx context.Context, request *milvuspb.CreateInde
|
||||
req: request,
|
||||
rootCoord: node.rootCoord,
|
||||
indexCoord: node.indexCoord,
|
||||
queryCoord: node.queryCoord,
|
||||
}
|
||||
|
||||
method := "CreateIndex"
|
||||
|
||||
@ -869,7 +869,7 @@ func (dpt *dropPartitionTask) PreExecute(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
collLoaded, err := isCollectionLoaded(ctx, dpt.queryCoord, []int64{collID})
|
||||
collLoaded, err := isCollectionLoaded(ctx, dpt.queryCoord, collID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -55,6 +55,7 @@ type createIndexTask struct {
|
||||
ctx context.Context
|
||||
rootCoord types.RootCoord
|
||||
indexCoord types.IndexCoord
|
||||
queryCoord types.QueryCoord
|
||||
result *commonpb.Status
|
||||
|
||||
isAutoIndex bool
|
||||
@ -283,7 +284,20 @@ func (cit *createIndexTask) PreExecute(ctx context.Context) error {
|
||||
}
|
||||
cit.fieldSchema = field
|
||||
// check index param, not accurate, only some static rules
|
||||
return cit.parseIndexParams()
|
||||
err = cit.parseIndexParams()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
loaded, err := isCollectionLoaded(ctx, cit.queryCoord, collID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if loaded {
|
||||
return fmt.Errorf("create index failed, collection is loaded, please release it first")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cit *createIndexTask) Execute(ctx context.Context) error {
|
||||
@ -504,7 +518,7 @@ func (dit *dropIndexTask) PreExecute(ctx context.Context) error {
|
||||
}
|
||||
dit.collectionID = collID
|
||||
|
||||
loaded, err := isCollectionLoaded(ctx, dit.queryCoord, []int64{collID})
|
||||
loaded, err := isCollectionLoaded(ctx, dit.queryCoord, collID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -21,14 +21,15 @@ import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/indexpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
func TestGetIndexStateTask_Execute(t *testing.T) {
|
||||
@ -203,3 +204,87 @@ func TestDropIndexTask_PreExecute(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateIndexTask_PreExecute(t *testing.T) {
|
||||
collectionName := "collection1"
|
||||
collectionID := UniqueID(1)
|
||||
fieldName := newTestSchema().Fields[0].Name
|
||||
|
||||
Params.Init()
|
||||
ic := newMockIndexCoord()
|
||||
ctx := context.Background()
|
||||
|
||||
mockCache := newMockCache()
|
||||
mockCache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
|
||||
return collectionID, nil
|
||||
})
|
||||
mockCache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) {
|
||||
return newTestSchema(), nil
|
||||
})
|
||||
globalMetaCache = mockCache
|
||||
|
||||
cit := createIndexTask{
|
||||
ctx: ctx,
|
||||
req: &milvuspb.CreateIndexRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_CreateIndex,
|
||||
},
|
||||
CollectionName: collectionName,
|
||||
FieldName: fieldName,
|
||||
},
|
||||
indexCoord: ic,
|
||||
queryCoord: nil,
|
||||
result: nil,
|
||||
collectionID: collectionID,
|
||||
}
|
||||
|
||||
t.Run("normal", func(t *testing.T) {
|
||||
showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
|
||||
return &querypb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
CollectionIDs: []int64{},
|
||||
}, nil
|
||||
}
|
||||
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
|
||||
qc.updateState(commonpb.StateCode_Healthy)
|
||||
cit.queryCoord = qc
|
||||
|
||||
err := cit.PreExecute(ctx)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("coll has been loaded", func(t *testing.T) {
|
||||
showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
|
||||
return &querypb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
CollectionIDs: []int64{collectionID},
|
||||
}, nil
|
||||
}
|
||||
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
|
||||
qc.updateState(commonpb.StateCode_Healthy)
|
||||
cit.queryCoord = qc
|
||||
err := cit.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("check load error", func(t *testing.T) {
|
||||
showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
|
||||
return &querypb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "fail reason",
|
||||
},
|
||||
CollectionIDs: nil,
|
||||
}, errors.New("error")
|
||||
}
|
||||
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
|
||||
qc.updateState(commonpb.StateCode_Healthy)
|
||||
cit.queryCoord = qc
|
||||
err := cit.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@ -2136,6 +2136,17 @@ func Test_createIndexTask_PreExecute(t *testing.T) {
|
||||
FieldName: fieldName,
|
||||
},
|
||||
}
|
||||
showCollectionMock := func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
|
||||
return &querypb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
CollectionIDs: []int64{},
|
||||
}, nil
|
||||
}
|
||||
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
|
||||
qc.updateState(commonpb.StateCode_Healthy)
|
||||
cit.queryCoord = qc
|
||||
|
||||
t.Run("normal", func(t *testing.T) {
|
||||
cache := newMockCache()
|
||||
|
||||
@ -836,7 +836,7 @@ func validateIndexName(indexName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func isCollectionLoaded(ctx context.Context, qc types.QueryCoord, collIDs []int64) (bool, error) {
|
||||
func isCollectionLoaded(ctx context.Context, qc types.QueryCoord, collID int64) (bool, error) {
|
||||
// get all loading collections
|
||||
resp, err := qc.ShowCollections(ctx, &querypb.ShowCollectionsRequest{
|
||||
CollectionIDs: nil,
|
||||
@ -848,23 +848,18 @@ func isCollectionLoaded(ctx context.Context, qc types.QueryCoord, collIDs []int6
|
||||
return false, errors.New(resp.Status.Reason)
|
||||
}
|
||||
|
||||
loaded := false
|
||||
LOOP:
|
||||
for _, loadedCollID := range resp.GetCollectionIDs() {
|
||||
for _, collID := range collIDs {
|
||||
if collID == loadedCollID {
|
||||
loaded = true
|
||||
break LOOP
|
||||
}
|
||||
if collID == loadedCollID {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return loaded, nil
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func isPartitionLoaded(ctx context.Context, qc types.QueryCoord, collIDs int64, partIDs []int64) (bool, error) {
|
||||
func isPartitionLoaded(ctx context.Context, qc types.QueryCoord, collID int64, partIDs []int64) (bool, error) {
|
||||
// get all loading collections
|
||||
resp, err := qc.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{
|
||||
CollectionID: collIDs,
|
||||
CollectionID: collID,
|
||||
PartitionIDs: nil,
|
||||
})
|
||||
if err != nil {
|
||||
@ -874,15 +869,12 @@ func isPartitionLoaded(ctx context.Context, qc types.QueryCoord, collIDs int64,
|
||||
return false, errors.New(resp.Status.Reason)
|
||||
}
|
||||
|
||||
loaded := false
|
||||
LOOP:
|
||||
for _, loadedPartID := range resp.GetPartitionIDs() {
|
||||
for _, partID := range partIDs {
|
||||
if partID == loadedPartID {
|
||||
loaded = true
|
||||
break LOOP
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return loaded, nil
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@ -825,7 +825,7 @@ func Test_isCollectionIsLoaded(t *testing.T) {
|
||||
}
|
||||
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
|
||||
qc.updateState(commonpb.StateCode_Healthy)
|
||||
loaded, err := isCollectionLoaded(ctx, qc, []int64{collID})
|
||||
loaded, err := isCollectionLoaded(ctx, qc, collID)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, loaded)
|
||||
})
|
||||
@ -843,7 +843,7 @@ func Test_isCollectionIsLoaded(t *testing.T) {
|
||||
}
|
||||
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
|
||||
qc.updateState(commonpb.StateCode_Healthy)
|
||||
loaded, err := isCollectionLoaded(ctx, qc, []int64{collID})
|
||||
loaded, err := isCollectionLoaded(ctx, qc, collID)
|
||||
assert.Error(t, err)
|
||||
assert.False(t, loaded)
|
||||
})
|
||||
@ -861,7 +861,7 @@ func Test_isCollectionIsLoaded(t *testing.T) {
|
||||
}
|
||||
qc := NewQueryCoordMock(withValidShardLeaders(), SetQueryCoordShowCollectionsFunc(showCollectionMock))
|
||||
qc.updateState(commonpb.StateCode_Healthy)
|
||||
loaded, err := isCollectionLoaded(ctx, qc, []int64{collID})
|
||||
loaded, err := isCollectionLoaded(ctx, qc, collID)
|
||||
assert.Error(t, err)
|
||||
assert.False(t, loaded)
|
||||
})
|
||||
|
||||
@ -1398,9 +1398,9 @@ class TestIndexString(TestcaseBase):
|
||||
data = cf.gen_default_list_data(ct.default_nb)
|
||||
collection_w.insert(data=data)
|
||||
collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index, index_name="vector_flat")
|
||||
collection_w.load()
|
||||
index, _ = self.index_wrap.init_index(collection_w.collection, default_string_field_name,
|
||||
default_string_index_params)
|
||||
collection_w.load()
|
||||
cf.assert_equal_index(index, collection_w.indexes[0])
|
||||
assert collection_w.num_entities == default_nb
|
||||
|
||||
|
||||
@ -748,7 +748,7 @@ class TestQueryParams(TestcaseBase):
|
||||
expected: verify query result
|
||||
"""
|
||||
# init collection with fields: int64, float, float_vec
|
||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True)[0:2]
|
||||
collection_w, vectors = self.init_collection_general(prefix, insert_data=True, is_index=True)[0:2]
|
||||
df = vectors[0]
|
||||
|
||||
# query with output_fields=["*", float_vector)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user