enhance: cherry pick patch of new DDL framework and CDC 3 (#45280)

issue: #43897, #44123
pr: #45266
also pick pr: #45237, #45264,#45244,#45275

fix: kafka should auto reset the offset from earliest to read (#45237)

issue: #44172, #45210, #44851,#45244

kafka will auto reset the offset to "latest" if the offset is
Out-of-range. the recovery of milvus wal cannot read any message from
that. So once the offset is out-of-range, kafka should read from eariest
to read the latest uncleared data.


https://kafka.apache.org/documentation/#consumerconfigs_auto.offset.reset

enhance: support alter collection/database with WAL-based DDL framework
(#45266)

issue: #43897

- Alter collection/database is implemented by WAL-based DDL framework
now.
- Support AlterCollection/AlterDatabase in wal now.
- Alter operation can be synced by new CDC now.
- Refactor some UT for alter DDL.

fix: milvus role cannot stop at initializing state (#45244)

issue: #45243

fix: support upgrading from 2.6.x -> 2.6.5 (#45264)

issue: #43897

---------

Signed-off-by: chyezh <chyezh@outlook.com>
This commit is contained in:
Zhen Ye 2025-11-04 20:21:37 +08:00 committed by GitHub
parent 9d7ef929e1
commit 122d024df4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
68 changed files with 2207 additions and 4683 deletions

View File

@ -21,12 +21,14 @@ import (
"fmt"
"os"
"os/signal"
"reflect"
"runtime/debug"
"strings"
"sync"
"syscall"
"time"
"github.com/cockroachdb/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/samber/lo"
@ -50,10 +52,10 @@ import (
rocksmqimpl "github.com/milvus-io/milvus/pkg/v2/mq/mqimpl/rocksmq/server"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/tracer"
"github.com/milvus-io/milvus/pkg/v2/util/conc"
"github.com/milvus-io/milvus/pkg/v2/util/etcd"
"github.com/milvus-io/milvus/pkg/v2/util/expr"
"github.com/milvus-io/milvus/pkg/v2/util/gc"
"github.com/milvus-io/milvus/pkg/v2/util/generic"
"github.com/milvus-io/milvus/pkg/v2/util/logutil"
"github.com/milvus-io/milvus/pkg/v2/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
@ -109,35 +111,29 @@ func runComponent[T component](ctx context.Context,
runWg *sync.WaitGroup,
creator func(context.Context, dependency.Factory) (T, error),
metricRegister func(*prometheus.Registry),
) component {
var role T
) *conc.Future[component] {
sign := make(chan struct{})
go func() {
future := conc.Go(func() (component, error) {
factory := dependency.NewFactory(localMsg)
var err error
role, err = creator(ctx, factory)
role, err := creator(ctx, factory)
if err != nil {
panic(err)
return nil, errors.Wrap(err, "create component failed")
}
if err := role.Prepare(); err != nil {
panic(err)
return nil, errors.Wrap(err, "prepare component failed")
}
close(sign)
healthz.Register(role)
metricRegister(Registry.GoRegistry)
if err := role.Run(); err != nil {
panic(err)
return nil, errors.Wrap(err, "run component failed")
}
runWg.Done()
}()
return role, nil
})
<-sign
healthz.Register(role)
metricRegister(Registry.GoRegistry)
if generic.IsZero(role) {
return nil
}
return role
return future
}
// MilvusRoles decides which components are brought up with Milvus.
@ -177,18 +173,15 @@ func (mr *MilvusRoles) printLDPreLoad() {
}
}
func (mr *MilvusRoles) runProxy(ctx context.Context, localMsg bool, wg *sync.WaitGroup) component {
wg.Add(1)
func (mr *MilvusRoles) runProxy(ctx context.Context, localMsg bool, wg *sync.WaitGroup) *conc.Future[component] {
return runComponent(ctx, localMsg, wg, components.NewProxy, metrics.RegisterProxy)
}
func (mr *MilvusRoles) runMixCoord(ctx context.Context, localMsg bool, wg *sync.WaitGroup) component {
wg.Add(1)
func (mr *MilvusRoles) runMixCoord(ctx context.Context, localMsg bool, wg *sync.WaitGroup) *conc.Future[component] {
return runComponent(ctx, localMsg, wg, components.NewMixCoord, metrics.RegisterMixCoord)
}
func (mr *MilvusRoles) runQueryNode(ctx context.Context, localMsg bool, wg *sync.WaitGroup) component {
wg.Add(1)
func (mr *MilvusRoles) runQueryNode(ctx context.Context, localMsg bool, wg *sync.WaitGroup) *conc.Future[component] {
// clear local storage
queryDataLocalPath := pathutil.GetPath(pathutil.RootCachePath, 0)
if !paramtable.Get().CommonCfg.EnablePosixMode.GetAsBool() {
@ -199,21 +192,73 @@ func (mr *MilvusRoles) runQueryNode(ctx context.Context, localMsg bool, wg *sync
return runComponent(ctx, localMsg, wg, components.NewQueryNode, metrics.RegisterQueryNode)
}
func (mr *MilvusRoles) runStreamingNode(ctx context.Context, localMsg bool, wg *sync.WaitGroup) component {
wg.Add(1)
func (mr *MilvusRoles) runStreamingNode(ctx context.Context, localMsg bool, wg *sync.WaitGroup) *conc.Future[component] {
return runComponent(ctx, localMsg, wg, components.NewStreamingNode, metrics.RegisterStreamingNode)
}
func (mr *MilvusRoles) runDataNode(ctx context.Context, localMsg bool, wg *sync.WaitGroup) component {
wg.Add(1)
func (mr *MilvusRoles) runDataNode(ctx context.Context, localMsg bool, wg *sync.WaitGroup) *conc.Future[component] {
return runComponent(ctx, localMsg, wg, components.NewDataNode, metrics.RegisterDataNode)
}
func (mr *MilvusRoles) runCDC(ctx context.Context, localMsg bool, wg *sync.WaitGroup) component {
wg.Add(1)
func (mr *MilvusRoles) runCDC(ctx context.Context, localMsg bool, wg *sync.WaitGroup) *conc.Future[component] {
return runComponent(ctx, localMsg, wg, components.NewCDC, metrics.RegisterCDC)
}
func (mr *MilvusRoles) waitForAllComponentsReady(cancel context.CancelFunc, componentFutureMap map[string]*conc.Future[component]) (map[string]component, error) {
roles := make([]string, 0, len(componentFutureMap))
futures := make([]*conc.Future[component], 0, len(componentFutureMap))
for role, future := range componentFutureMap {
roles = append(roles, role)
futures = append(futures, future)
}
selectCases := make([]reflect.SelectCase, 1+len(componentFutureMap))
selectCases[0] = reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(mr.closed),
}
for i, future := range futures {
selectCases[i+1] = reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(future.Inner()),
}
}
componentMap := make(map[string]component, len(componentFutureMap))
readyCount := 0
var gerr error
for {
index, _, _ := reflect.Select(selectCases)
if index == 0 {
cancel()
log.Warn("components are not ready before closing, wait for the start of components to be canceled...")
} else {
role := roles[index-1]
component, err := futures[index-1].Await()
readyCount++
if err != nil {
if gerr == nil {
gerr = errors.Wrapf(err, "component %s is not ready before closing", role)
cancel()
}
log.Warn("component is not ready before closing", zap.String("role", role), zap.Error(err))
} else {
componentMap[role] = component
log.Info("component is ready", zap.String("role", role))
}
}
selectCases[index] = reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(nil),
}
if readyCount == len(componentFutureMap) {
break
}
}
if gerr != nil {
return nil, errors.Wrap(gerr, "failed to wait for all components ready")
}
return componentMap, nil
}
func (mr *MilvusRoles) setupLogger() {
params := paramtable.Get()
logConfig := log.Config{
@ -406,48 +451,51 @@ func (mr *MilvusRoles) Run() {
var wg sync.WaitGroup
local := mr.Local
componentMap := make(map[string]component)
var mixCoord component
var proxy, dataNode, queryNode, streamingNode, cdc component
componentFutureMap := make(map[string]*conc.Future[component])
if (mr.EnableRootCoord && mr.EnableDataCoord && mr.EnableQueryCoord) || mr.EnableMixCoord {
paramtable.SetLocalComponentEnabled(typeutil.MixCoordRole)
mixCoord = mr.runMixCoord(ctx, local, &wg)
componentMap[typeutil.MixCoordRole] = mixCoord
mixCoord := mr.runMixCoord(ctx, local, &wg)
componentFutureMap[typeutil.MixCoordRole] = mixCoord
}
if mr.EnableQueryNode {
paramtable.SetLocalComponentEnabled(typeutil.QueryNodeRole)
queryNode = mr.runQueryNode(ctx, local, &wg)
componentMap[typeutil.QueryNodeRole] = queryNode
queryNode := mr.runQueryNode(ctx, local, &wg)
componentFutureMap[typeutil.QueryNodeRole] = queryNode
}
if mr.EnableDataNode {
paramtable.SetLocalComponentEnabled(typeutil.DataNodeRole)
dataNode = mr.runDataNode(ctx, local, &wg)
componentMap[typeutil.DataNodeRole] = dataNode
dataNode := mr.runDataNode(ctx, local, &wg)
componentFutureMap[typeutil.DataNodeRole] = dataNode
}
if mr.EnableProxy {
paramtable.SetLocalComponentEnabled(typeutil.ProxyRole)
proxy = mr.runProxy(ctx, local, &wg)
componentMap[typeutil.ProxyRole] = proxy
proxy := mr.runProxy(ctx, local, &wg)
componentFutureMap[typeutil.ProxyRole] = proxy
}
if mr.EnableStreamingNode {
// Before initializing the local streaming node, make sure the local registry is ready.
paramtable.SetLocalComponentEnabled(typeutil.StreamingNodeRole)
streamingNode = mr.runStreamingNode(ctx, local, &wg)
componentMap[typeutil.StreamingNodeRole] = streamingNode
streamingNode := mr.runStreamingNode(ctx, local, &wg)
componentFutureMap[typeutil.StreamingNodeRole] = streamingNode
}
if mr.EnableCDC {
paramtable.SetLocalComponentEnabled(typeutil.CDCRole)
cdc = mr.runCDC(ctx, local, &wg)
componentMap[typeutil.CDCRole] = cdc
cdc := mr.runCDC(ctx, local, &wg)
componentFutureMap[typeutil.CDCRole] = cdc
}
wg.Wait()
componentMap, err := mr.waitForAllComponentsReady(cancel, componentFutureMap)
if err != nil {
log.Warn("Failed to wait for all components ready", zap.Error(err))
return
}
log.Info("All components are ready", zap.Strings("roles", lo.Keys(componentMap)))
http.RegisterStopComponent(func(role string) error {
if len(role) == 0 || componentMap[role] == nil {
@ -505,6 +553,13 @@ func (mr *MilvusRoles) Run() {
<-mr.closed
mixCoord := componentMap[typeutil.MixCoordRole]
streamingNode := componentMap[typeutil.StreamingNodeRole]
queryNode := componentMap[typeutil.QueryNodeRole]
dataNode := componentMap[typeutil.DataNodeRole]
cdc := componentMap[typeutil.CDCRole]
proxy := componentMap[typeutil.ProxyRole]
// stop coordinators first
coordinators := []component{mixCoord}
for idx, coord := range coordinators {

View File

@ -19,6 +19,7 @@ import (
// Therefore, we need to read configs from 2.3.x and modify meta data if necessary.
type MmapMigration struct {
rootcoordMeta rootcoord.IMetaTable
rootCoordCatalog metastore.RootCoordCatalog
tsoAllocator tso.Allocator
datacoordCatalog metastore.DataCoordCatalog
}
@ -58,8 +59,7 @@ func (m *MmapMigration) MigrateRootCoordCollection(ctx context.Context) {
newColl.Properties = updateOrAddMmapKey(newColl.Properties, common.MmapEnabledKey, "true")
fmt.Printf("migrate collection %v, %s\n", collection.CollectionID, collection.Name)
if err := m.rootcoordMeta.AlterCollection(ctx, collection, newColl, ts, false); err != nil {
if err := m.rootCoordCatalog.AlterCollection(ctx, collection, newColl, metastore.MODIFY, ts, false); err != nil {
panic(err)
}
}
@ -100,9 +100,10 @@ func (m *MmapMigration) MigrateIndexCoordCollection(ctx context.Context) {
}
}
func NewMmapMigration(rootcoordMeta rootcoord.IMetaTable, tsoAllocator tso.Allocator, datacoordCatalog metastore.DataCoordCatalog) *MmapMigration {
func NewMmapMigration(rootcoordMeta rootcoord.IMetaTable, tsoAllocator tso.Allocator, datacoordCatalog metastore.DataCoordCatalog, rootCoordCatalog metastore.RootCoordCatalog) *MmapMigration {
return &MmapMigration{
rootcoordMeta: rootcoordMeta,
rootCoordCatalog: rootCoordCatalog,
tsoAllocator: tsoAllocator,
datacoordCatalog: datacoordCatalog,
}

View File

@ -42,9 +42,9 @@ func main() {
}
fmt.Printf("MmapDirPath: %s\n", paramtable.Get().QueryNodeCfg.MmapDirPath.GetValue())
allocator := prepareTsoAllocator()
rootCoordMeta := prepareRootCoordMeta(context.Background(), allocator)
rootCoordMeta, rootCoordCatalog := prepareRootCoordMeta(context.Background(), allocator)
dataCoordCatalog := prepareDataCoordCatalog()
m := mmap.NewMmapMigration(rootCoordMeta, allocator, dataCoordCatalog)
m := mmap.NewMmapMigration(rootCoordMeta, allocator, dataCoordCatalog, rootCoordCatalog)
m.Migrate(context.Background())
}
@ -117,7 +117,7 @@ func metaKVCreator() (kv.MetaKv, error) {
etcdkv.WithRequestTimeout(paramtable.Get().ServiceParam.EtcdCfg.RequestTimeout.GetAsDuration(time.Millisecond))), nil
}
func prepareRootCoordMeta(ctx context.Context, allocator tso.Allocator) rootcoord.IMetaTable {
func prepareRootCoordMeta(ctx context.Context, allocator tso.Allocator) (rootcoord.IMetaTable, metastore.RootCoordCatalog) {
var catalog metastore.RootCoordCatalog
var err error
@ -158,7 +158,7 @@ func prepareRootCoordMeta(ctx context.Context, allocator tso.Allocator) rootcoor
panic(err)
}
return meta
return meta, catalog
}
func prepareDataCoordCatalog() metastore.DataCoordCatalog {

2
go.mod
View File

@ -74,7 +74,7 @@ require (
github.com/jolestar/go-commons-pool/v2 v2.1.2
github.com/magiconair/properties v1.8.7
github.com/milvus-io/milvus/client/v2 v2.0.0-00010101000000-000000000000
github.com/milvus-io/milvus/pkg/v2 v2.6.3
github.com/milvus-io/milvus/pkg/v2 v2.6.4
github.com/pkg/errors v0.9.1
github.com/remeh/sizedwaitgroup v1.0.0
github.com/shirou/gopsutil/v4 v4.24.10

View File

@ -112,6 +112,7 @@ func initStreamingSystem(t *testing.T) {
<-ctx.Done()
return ctx.Err()
}).Maybe()
b.EXPECT().WaitUntilWALbasedDDLReady(mock.Anything).Return(nil).Maybe()
b.EXPECT().Close().Return().Maybe()
balance.Register(b)
channel.ResetStaticPChannelStatsManager()

View File

@ -1,52 +0,0 @@
/*
* Licensed to the LF AI & Data foundation under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package msghandlerimpl
import (
"context"
"github.com/milvus-io/milvus/internal/flushcommon/broker"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)
type msgHandlerImpl struct {
broker broker.Broker
}
func (m *msgHandlerImpl) HandleCreateSegment(ctx context.Context, createSegmentMsg message.ImmutableCreateSegmentMessageV2) error {
panic("unreachable code")
}
func (m *msgHandlerImpl) HandleFlush(flushMsg message.ImmutableFlushMessageV2) error {
panic("unreachable code")
}
func (m *msgHandlerImpl) HandleManualFlush(flushMsg message.ImmutableManualFlushMessageV2) error {
panic("unreachable code")
}
func (impl *msgHandlerImpl) HandleSchemaChange(ctx context.Context, msg message.ImmutableSchemaChangeMessageV2) error {
panic("unreachable code")
}
func NewMsgHandlerImpl(broker broker.Broker) *msgHandlerImpl {
return &msgHandlerImpl{
broker: broker,
}
}

View File

@ -1,43 +0,0 @@
/*
* Licensed to the LF AI & Data foundation under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package msghandlerimpl
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/flushcommon/broker"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
func TestMsgHandlerImpl(t *testing.T) {
paramtable.Init()
b := broker.NewMockBroker(t)
m := NewMsgHandlerImpl(b)
assert.Panics(t, func() {
m.HandleCreateSegment(nil, nil)
})
assert.Panics(t, func() {
m.HandleFlush(nil)
})
assert.Panics(t, func() {
m.HandleManualFlush(nil)
})
}

View File

@ -114,7 +114,7 @@ func (m *delegatorMsgstreamAdaptor) Seek(ctx context.Context, msgPositions []*ms
// only consume messages with timestamp >= position timestamp
options.DeliverFilterTimeTickGTE(position.GetTimestamp()),
// only consume insert and delete messages
options.DeliverFilterMessageType(message.MessageTypeInsert, message.MessageTypeDelete, message.MessageTypeSchemaChange),
options.DeliverFilterMessageType(message.MessageTypeInsert, message.MessageTypeDelete, message.MessageTypeSchemaChange, message.MessageTypeAlterCollection),
},
MessageHandler: handler,
})

View File

@ -284,6 +284,19 @@ func (ddn *ddNode) Operate(in []Msg) []Msg {
)
logger.Info("receive schema change message")
ddn.msgHandler.HandleSchemaChange(ddn.ctx, schemaMsg.SchemaChangeMessage)
case commonpb.MsgType_AlterCollection:
alterCollectionMsg := msg.(*adaptor.AlterCollectionMessageBody)
logger := log.With(
zap.String("vchannel", ddn.Name()),
zap.Int32("msgType", int32(msg.Type())),
zap.Uint64("timetick", alterCollectionMsg.AlterCollectionMessage.TimeTick()),
)
logger.Info("receive put collection message")
if err := ddn.msgHandler.HandleAlterCollection(ddn.ctx, alterCollectionMsg.AlterCollectionMessage); err != nil {
logger.Warn("handle put collection message failed", zap.Error(err))
} else {
logger.Info("handle put collection message success")
}
}
}

View File

@ -32,6 +32,8 @@ type MsgHandler interface {
HandleManualFlush(flushMsg message.ImmutableManualFlushMessageV2) error
HandleSchemaChange(ctx context.Context, schemaChangeMsg message.ImmutableSchemaChangeMessageV2) error
HandleAlterCollection(ctx context.Context, alterCollectionMsg message.ImmutableAlterCollectionMessageV2) error
}
func ConvertInternalImportFile(file *msgpb.ImportFile, _ int) *internalpb.ImportFile {

View File

@ -23,8 +23,10 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/common"
pb "github.com/milvus-io/milvus/pkg/v2/proto/etcdpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)
// TODO: These collection is dirty implementation and easy to be broken, we should drop it in the future.
type Collection struct {
TenantID string
DBID int64
@ -48,6 +50,7 @@ type Collection struct {
State pb.CollectionState
EnableDynamicField bool
UpdateTimestamp uint64
SchemaVersion int32
}
func (c *Collection) Available() bool {
@ -78,6 +81,7 @@ func (c *Collection) ShallowClone() *Collection {
EnableDynamicField: c.EnableDynamicField,
Functions: c.Functions,
UpdateTimestamp: c.UpdateTimestamp,
SchemaVersion: c.SchemaVersion,
}
}
@ -105,6 +109,7 @@ func (c *Collection) Clone() *Collection {
EnableDynamicField: c.EnableDynamicField,
Functions: CloneFunctions(c.Functions),
UpdateTimestamp: c.UpdateTimestamp,
SchemaVersion: c.SchemaVersion,
}
}
@ -130,6 +135,33 @@ func (c *Collection) Equal(other Collection) bool {
c.EnableDynamicField == other.EnableDynamicField
}
func (c *Collection) ApplyUpdates(header *message.AlterCollectionMessageHeader, body *message.AlterCollectionMessageBody) {
updateMask := header.UpdateMask
updates := body.Updates
for _, field := range updateMask.GetPaths() {
switch field {
case message.FieldMaskDB:
c.DBID = updates.DbId
c.DBName = updates.DbName
case message.FieldMaskCollectionName:
c.Name = updates.CollectionName
case message.FieldMaskCollectionDescription:
c.Description = updates.Description
case message.FieldMaskCollectionConsistencyLevel:
c.ConsistencyLevel = updates.ConsistencyLevel
case message.FieldMaskCollectionProperties:
c.Properties = updates.Properties
case message.FieldMaskCollectionSchema:
c.AutoID = updates.Schema.AutoID
c.Fields = UnmarshalFieldModels(updates.Schema.Fields)
c.EnableDynamicField = updates.Schema.EnableDynamicField
c.Functions = UnmarshalFunctionModels(updates.Schema.Functions)
c.StructArrayFields = UnmarshalStructArrayFieldModels(updates.Schema.StructArrayFields)
c.SchemaVersion = updates.Schema.Version
}
}
}
func UnmarshalCollectionModel(coll *pb.CollectionInfo) *Collection {
if coll == nil {
return nil
@ -165,6 +197,7 @@ func UnmarshalCollectionModel(coll *pb.CollectionInfo) *Collection {
Properties: coll.Properties,
EnableDynamicField: coll.Schema.EnableDynamicField,
UpdateTimestamp: coll.UpdateTimestamp,
SchemaVersion: coll.Schema.Version,
}
}
@ -215,6 +248,7 @@ func marshalCollectionModelWithConfig(coll *Collection, c *config) *pb.Collectio
AutoID: coll.AutoID,
EnableDynamicField: coll.EnableDynamicField,
DbName: coll.DBName,
Version: coll.SchemaVersion,
}
if c.withFields {

View File

@ -22,6 +22,53 @@ func (_m *MockMsgHandler) EXPECT() *MockMsgHandler_Expecter {
return &MockMsgHandler_Expecter{mock: &_m.Mock}
}
// HandleAlterCollection provides a mock function with given fields: ctx, alterCollectionMsg
func (_m *MockMsgHandler) HandleAlterCollection(ctx context.Context, alterCollectionMsg message.ImmutableAlterCollectionMessageV2) error {
ret := _m.Called(ctx, alterCollectionMsg)
if len(ret) == 0 {
panic("no return value specified for HandleAlterCollection")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, message.ImmutableAlterCollectionMessageV2) error); ok {
r0 = rf(ctx, alterCollectionMsg)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockMsgHandler_HandleAlterCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HandleAlterCollection'
type MockMsgHandler_HandleAlterCollection_Call struct {
*mock.Call
}
// HandleAlterCollection is a helper method to define mock.On call
// - ctx context.Context
// - alterCollectionMsg message.ImmutableAlterCollectionMessageV2
func (_e *MockMsgHandler_Expecter) HandleAlterCollection(ctx interface{}, alterCollectionMsg interface{}) *MockMsgHandler_HandleAlterCollection_Call {
return &MockMsgHandler_HandleAlterCollection_Call{Call: _e.mock.On("HandleAlterCollection", ctx, alterCollectionMsg)}
}
func (_c *MockMsgHandler_HandleAlterCollection_Call) Run(run func(ctx context.Context, alterCollectionMsg message.ImmutableAlterCollectionMessageV2)) *MockMsgHandler_HandleAlterCollection_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(message.ImmutableAlterCollectionMessageV2))
})
return _c
}
func (_c *MockMsgHandler_HandleAlterCollection_Call) Return(_a0 error) *MockMsgHandler_HandleAlterCollection_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockMsgHandler_HandleAlterCollection_Call) RunAndReturn(run func(context.Context, message.ImmutableAlterCollectionMessageV2) error) *MockMsgHandler_HandleAlterCollection_Call {
_c.Call.Return(run)
return _c
}
// HandleCreateSegment provides a mock function with given fields: ctx, createSegmentMsg
func (_m *MockMsgHandler) HandleCreateSegment(ctx context.Context, createSegmentMsg message.ImmutableCreateSegmentMessageV2) error {
ret := _m.Called(ctx, createSegmentMsg)

View File

@ -573,6 +573,52 @@ func (_c *MockBalancer_UpdateReplicateConfiguration_Call) RunAndReturn(run func(
return _c
}
// WaitUntilWALbasedDDLReady provides a mock function with given fields: ctx
func (_m *MockBalancer) WaitUntilWALbasedDDLReady(ctx context.Context) error {
ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for WaitUntilWALbasedDDLReady")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
r0 = rf(ctx)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockBalancer_WaitUntilWALbasedDDLReady_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WaitUntilWALbasedDDLReady'
type MockBalancer_WaitUntilWALbasedDDLReady_Call struct {
*mock.Call
}
// WaitUntilWALbasedDDLReady is a helper method to define mock.On call
// - ctx context.Context
func (_e *MockBalancer_Expecter) WaitUntilWALbasedDDLReady(ctx interface{}) *MockBalancer_WaitUntilWALbasedDDLReady_Call {
return &MockBalancer_WaitUntilWALbasedDDLReady_Call{Call: _e.mock.On("WaitUntilWALbasedDDLReady", ctx)}
}
func (_c *MockBalancer_WaitUntilWALbasedDDLReady_Call) Run(run func(ctx context.Context)) *MockBalancer_WaitUntilWALbasedDDLReady_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context))
})
return _c
}
func (_c *MockBalancer_WaitUntilWALbasedDDLReady_Call) Return(_a0 error) *MockBalancer_WaitUntilWALbasedDDLReady_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockBalancer_WaitUntilWALbasedDDLReady_Call) RunAndReturn(run func(context.Context) error) *MockBalancer_WaitUntilWALbasedDDLReady_Call {
_c.Call.Return(run)
return _c
}
// WatchChannelAssignments provides a mock function with given fields: ctx, cb
func (_m *MockBalancer) WatchChannelAssignments(ctx context.Context, cb balancer.WatchChannelAssignmentsCallback) error {
ret := _m.Called(ctx, cb)

View File

@ -34,12 +34,14 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/rgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/coordinator/snmanager"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/json"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/metastore"
"github.com/milvus-io/milvus/internal/metastore/kv/querycoord"
"github.com/milvus-io/milvus/internal/mocks/distributed/mock_streaming"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_broadcaster"
"github.com/milvus-io/milvus/internal/querycoordv2/balance"
"github.com/milvus-io/milvus/internal/querycoordv2/checkers"
@ -51,6 +53,8 @@ import (
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/task"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
sbalance "github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/broadcast"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/internal/util/proxyutil"
@ -116,6 +120,16 @@ func initStreamingSystem() {
wal.EXPECT().ControlChannel().Return(funcutil.GetControlChannel("by-dev-rootcoord-dml_0"))
streaming.SetWALForTest(wal)
snmanager.ResetStreamingNodeManager()
b := mock_balancer.NewMockBalancer(t)
b.EXPECT().WaitUntilWALbasedDDLReady(mock.Anything).Return(nil).Maybe()
b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb balancer.WatchChannelAssignmentsCallback) error {
<-ctx.Done()
return ctx.Err()
}).Maybe()
b.EXPECT().Close().Return().Maybe()
sbalance.Register(b)
bapi := mock_broadcaster.NewMockBroadcastAPI(t)
bapi.EXPECT().Broadcast(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) {
results := make(map[string]*message.AppendResult)

View File

@ -142,6 +142,13 @@ func (fNode *filterNode) filtrate(c *Collection, msg msgstream.TsMsg) error {
return merr.WrapErrCollectionNotFound(header.GetCollectionId())
}
return nil
case commonpb.MsgType_AlterCollection:
putCollectionMsg := msg.(*adaptor.AlterCollectionMessageBody)
header := putCollectionMsg.AlterCollectionMessage.Header()
if header.GetCollectionId() != fNode.collectionID {
return merr.WrapErrCollectionNotFound(header.GetCollectionId())
}
return nil
default:
return merr.WrapErrParameterInvalid("msgType is Insert or Delete", "not")
}

View File

@ -23,6 +23,7 @@ import (
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message/adaptor"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message/messageutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metricsinfo"
)
@ -61,6 +62,14 @@ func (msg *insertNodeMsg) append(taskMsg msgstream.TsMsg) error {
}
msg.schema = body.GetSchema()
msg.schemaVersion = taskMsg.BeginTs()
case commonpb.MsgType_AlterCollection:
putCollectionMsg := taskMsg.(*adaptor.AlterCollectionMessageBody)
header := putCollectionMsg.AlterCollectionMessage.Header()
if messageutil.IsSchemaChange(header) {
body := putCollectionMsg.AlterCollectionMessage.MustBody()
msg.schema = body.GetUpdates().GetSchema()
msg.schemaVersion = taskMsg.BeginTs()
}
default:
return merr.WrapErrParameterInvalid("msgType is Insert or Delete", "not")
}

View File

@ -1,136 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/util/proxyutil"
"github.com/milvus-io/milvus/pkg/v2/log"
)
type addCollectionFieldTask struct {
baseTask
Req *milvuspb.AddCollectionFieldRequest
fieldSchema *schemapb.FieldSchema
}
func (t *addCollectionFieldTask) Prepare(ctx context.Context) error {
if err := CheckMsgType(t.Req.GetBase().GetMsgType(), commonpb.MsgType_AddCollectionField); err != nil {
return err
}
t.fieldSchema = &schemapb.FieldSchema{}
err := proto.Unmarshal(t.Req.Schema, t.fieldSchema)
if err != nil {
return err
}
if err := checkFieldSchema([]*schemapb.FieldSchema{t.fieldSchema}); err != nil {
return err
}
return nil
}
func (t *addCollectionFieldTask) Execute(ctx context.Context) error {
oldColl, err := t.core.meta.GetCollectionByName(ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), t.ts)
if err != nil {
log.Ctx(ctx).Warn("get collection failed during add field",
zap.String("collectionName", t.Req.GetCollectionName()), zap.Uint64("ts", t.ts))
return err
}
// assign field id
t.fieldSchema.FieldID = nextFieldID(oldColl)
newField := model.UnmarshalFieldModel(t.fieldSchema)
ts := t.GetTs()
t.Req.CollectionID = oldColl.CollectionID
return executeAddCollectionFieldTaskSteps(ctx, t.core, oldColl, newField, t.Req, ts)
}
func (t *addCollectionFieldTask) GetLockerKey() LockerKey {
collection := t.core.getCollectionIDStr(t.ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), 0)
return NewLockerKeyChain(
NewClusterLockerKey(false),
NewDatabaseLockerKey(t.Req.GetDbName(), false),
NewCollectionLockerKey(collection, true),
)
}
type collInfoProvider interface {
GetDbName() string
GetCollectionName() string
GetCollectionID() int64
}
func executeAddCollectionFieldTaskSteps(ctx context.Context,
core *Core,
col *model.Collection,
newField *model.Field,
req collInfoProvider,
ts Timestamp,
) error {
redoTask := newBaseRedoTask(core.stepExecutor)
updatedCollection := col.Clone()
updatedCollection.Fields = append(updatedCollection.Fields, newField)
if newField.IsDynamic {
updatedCollection.EnableDynamicField = true
}
redoTask.AddSyncStep(&WriteSchemaChangeWALStep{
baseStep: baseStep{core: core},
collection: updatedCollection,
})
oldColl := col.Clone()
redoTask.AddSyncStep(&AddCollectionFieldStep{
baseStep: baseStep{core: core},
oldColl: oldColl,
updatedCollection: updatedCollection,
newField: newField,
})
redoTask.AddSyncStep(&BroadcastAlteredCollectionStep{
baseStep: baseStep{core: core},
req: &milvuspb.AlterCollectionRequest{
DbName: req.GetDbName(),
CollectionName: req.GetCollectionName(),
CollectionID: req.GetCollectionID(),
},
core: core,
})
// field needs to be refreshed in the cache
aliases := core.meta.ListAliasesByID(ctx, oldColl.CollectionID)
redoTask.AddSyncStep(&expireCacheStep{
baseStep: baseStep{core: core},
dbName: req.GetDbName(),
collectionNames: append(aliases, req.GetCollectionName()),
collectionID: oldColl.CollectionID,
ts: ts,
opts: []proxyutil.ExpireCacheOpt{proxyutil.SetMsgType(commonpb.MsgType_AddCollectionField)},
})
return redoTask.Execute(ctx)
}

View File

@ -1,332 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/mocks/distributed/mock_streaming"
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
func Test_AddCollectionFieldTask_Prepare(t *testing.T) {
t.Run("invalid msg type", func(t *testing.T) {
task := &addCollectionFieldTask{Req: &milvuspb.AddCollectionFieldRequest{Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropCollection}}}
err := task.Prepare(context.Background())
assert.Error(t, err)
})
t.Run("check field failed", func(t *testing.T) {
fieldSchema := &schemapb.FieldSchema{
DataType: schemapb.DataType_Int64,
DefaultValue: &schemapb.ValueField{
Data: &schemapb.ValueField_BoolData{
BoolData: false,
},
},
}
bytes, err := proto.Marshal(fieldSchema)
assert.NoError(t, err)
task := &addCollectionFieldTask{Req: &milvuspb.AddCollectionFieldRequest{Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AddCollectionField}, Schema: bytes}}
err = task.Prepare(context.Background())
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
fieldSchema := &schemapb.FieldSchema{
DataType: schemapb.DataType_Bool,
DefaultValue: &schemapb.ValueField{
Data: &schemapb.ValueField_BoolData{
BoolData: false,
},
},
}
bytes, err := proto.Marshal(fieldSchema)
assert.NoError(t, err)
task := &addCollectionFieldTask{Req: &milvuspb.AddCollectionFieldRequest{Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AddCollectionField}, Schema: bytes}}
err = task.Prepare(context.Background())
assert.NoError(t, err)
})
}
func Test_AddCollectionFieldTask_Execute(t *testing.T) {
b := mock_streaming.NewMockBroadcast(t)
wal := mock_streaming.NewMockWALAccesser(t)
wal.EXPECT().Broadcast().Return(b).Maybe()
streaming.SetWALForTest(wal)
testCollection := &model.Collection{
CollectionID: int64(1),
Fields: []*model.Field{
{
Name: "pk",
IsPrimaryKey: true,
DataType: schemapb.DataType_Int64,
},
{
Name: "vec",
DataType: schemapb.DataType_FloatVector,
},
},
PhysicalChannelNames: []string{"dml_ch_01", "dml_ch_02"},
VirtualChannelNames: []string{"dml_ch_01", "dml_ch_02"},
}
t.Run("failed_to_get_collection", func(t *testing.T) {
metaTable := mockrootcoord.NewIMetaTable(t)
metaTable.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, "not_existed_coll", mock.Anything).Return(nil, merr.WrapErrCollectionNotFound("not_existed_coll"))
core := newTestCore(withMeta(metaTable))
task := &addCollectionFieldTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.AddCollectionFieldRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterAlias},
CollectionName: "not_existed_coll",
},
}
err := task.Execute(context.Background())
assert.Error(t, err, "error shall be return when get collection failed")
})
t.Run("write_wal_fail", func(t *testing.T) {
meta := mockrootcoord.NewIMetaTable(t)
meta.EXPECT().GetCollectionByName(
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(testCollection, nil)
meta.EXPECT().ListAliasesByID(mock.Anything, mock.Anything).Return([]string{})
broker := newMockBroker()
broker.BroadcastAlteredCollectionFunc = func(ctx context.Context, req *milvuspb.AlterCollectionRequest) error {
return errors.New("mock")
}
alloc := newMockIDAllocator()
core := newTestCore(withValidProxyManager(), withMeta(meta), withBroker(broker), withIDAllocator(alloc))
task := &addCollectionFieldTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.AddCollectionFieldRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterAlias},
CollectionName: "coll",
},
fieldSchema: &schemapb.FieldSchema{
Name: "fid",
DataType: schemapb.DataType_Bool,
Nullable: true,
},
}
t.Run("write_schema_change_fail", func(t *testing.T) {
b.EXPECT().Append(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()
err := task.Execute(context.Background())
assert.Error(t, err)
})
})
t.Run("add field step failed", func(t *testing.T) {
b.EXPECT().Append(mock.Anything, mock.Anything).Return(&types.BroadcastAppendResult{
AppendResults: map[string]*types.AppendResult{
"dml_ch_01": {TimeTick: 100},
"dml_ch_02": {TimeTick: 101},
},
}, nil).Times(1)
meta := mockrootcoord.NewIMetaTable(t)
meta.EXPECT().GetCollectionByName(
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(testCollection, nil)
meta.EXPECT().AlterCollection(
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(errors.New("mock"))
meta.EXPECT().ListAliasesByID(mock.Anything, mock.Anything).Return([]string{})
alloc := newMockIDAllocator()
core := newTestCore(withValidProxyManager(), withMeta(meta), withIDAllocator(alloc))
task := &addCollectionFieldTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.AddCollectionFieldRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterAlias},
CollectionName: "coll",
},
fieldSchema: &schemapb.FieldSchema{
Name: "fid",
DataType: schemapb.DataType_Bool,
Nullable: true,
},
}
err := task.Execute(context.Background())
assert.Error(t, err)
})
t.Run("broadcast add field step failed", func(t *testing.T) {
b.EXPECT().Append(mock.Anything, mock.Anything).Return(&types.BroadcastAppendResult{
AppendResults: map[string]*types.AppendResult{
"dml_ch_01": {TimeTick: 100},
"dml_ch_02": {TimeTick: 101},
},
}, nil).Times(1)
meta := mockrootcoord.NewIMetaTable(t)
meta.EXPECT().GetCollectionByName(
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(testCollection, nil)
meta.EXPECT().AlterCollection(
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(nil)
meta.EXPECT().ListAliasesByID(mock.Anything, mock.Anything).Return([]string{})
broker := newMockBroker()
broker.BroadcastAlteredCollectionFunc = func(ctx context.Context, req *milvuspb.AlterCollectionRequest) error {
return errors.New("mock")
}
alloc := newMockIDAllocator()
core := newTestCore(withValidProxyManager(), withMeta(meta), withBroker(broker), withIDAllocator(alloc))
task := &addCollectionFieldTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.AddCollectionFieldRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterAlias},
CollectionName: "coll",
},
fieldSchema: &schemapb.FieldSchema{
Name: "fid",
DataType: schemapb.DataType_Bool,
Nullable: true,
},
}
err := task.Execute(context.Background())
assert.Error(t, err)
})
t.Run("expire cache failed", func(t *testing.T) {
b.EXPECT().Append(mock.Anything, mock.Anything).Return(&types.BroadcastAppendResult{
AppendResults: map[string]*types.AppendResult{
"dml_ch_01": {TimeTick: 100},
"dml_ch_02": {TimeTick: 101},
},
}, nil).Times(1)
meta := mockrootcoord.NewIMetaTable(t)
meta.EXPECT().GetCollectionByName(
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(testCollection, nil)
meta.EXPECT().AlterCollection(
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(nil)
meta.EXPECT().ListAliasesByID(mock.Anything, mock.Anything).Return([]string{})
broker := newMockBroker()
broker.BroadcastAlteredCollectionFunc = func(ctx context.Context, req *milvuspb.AlterCollectionRequest) error {
return nil
}
alloc := newMockIDAllocator()
core := newTestCore(withInvalidProxyManager(), withMeta(meta), withBroker(broker), withIDAllocator(alloc))
task := &addCollectionFieldTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.AddCollectionFieldRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterAlias},
CollectionName: "coll",
},
fieldSchema: &schemapb.FieldSchema{
Name: "fid",
DataType: schemapb.DataType_Bool,
Nullable: true,
},
}
err := task.Execute(context.Background())
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
b.EXPECT().Append(mock.Anything, mock.Anything).Return(&types.BroadcastAppendResult{
AppendResults: map[string]*types.AppendResult{
"dml_ch_01": {TimeTick: 100},
"dml_ch_02": {TimeTick: 101},
},
}, nil).Times(1)
meta := mockrootcoord.NewIMetaTable(t)
meta.EXPECT().GetCollectionByName(
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(testCollection, nil)
meta.EXPECT().AlterCollection(
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(nil)
meta.EXPECT().ListAliasesByID(mock.Anything, mock.Anything).Return([]string{})
broker := newMockBroker()
broker.BroadcastAlteredCollectionFunc = func(ctx context.Context, req *milvuspb.AlterCollectionRequest) error {
return nil
}
alloc := newMockIDAllocator()
core := newTestCore(withValidProxyManager(), withMeta(meta), withBroker(broker), withIDAllocator(alloc))
task := &addCollectionFieldTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.AddCollectionFieldRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AddCollectionField},
CollectionName: "coll",
},
fieldSchema: &schemapb.FieldSchema{
Name: "fid",
DataType: schemapb.DataType_Bool,
Nullable: true,
},
}
err := task.Execute(context.Background())
assert.NoError(t, err)
})
}

View File

@ -17,93 +17,17 @@
package rootcoord
import (
"context"
"strconv"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/internal/util/proxyutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
type alterCollectionTask struct {
baseTask
Req *milvuspb.AlterCollectionRequest
}
func (a *alterCollectionTask) Prepare(ctx context.Context) error {
if a.Req.GetCollectionName() == "" {
return errors.New("alter collection failed, collection name does not exists")
}
if funcutil.SliceContain(a.Req.GetDeleteKeys(), common.EnableDynamicSchemaKey) {
return merr.WrapErrParameterInvalidMsg("cannot delete key %s, dynamic field schema could support set to true/false", common.EnableDynamicSchemaKey)
}
return nil
}
func (a *alterCollectionTask) Execute(ctx context.Context) error {
log := log.Ctx(ctx).With(
zap.String("alterCollectionTask", a.Req.GetCollectionName()),
zap.Int64("collectionID", a.Req.GetCollectionID()),
zap.Uint64("ts", a.GetTs()))
if a.Req.GetProperties() == nil && a.Req.GetDeleteKeys() == nil {
log.Warn("alter collection with empty properties and delete keys, expected to set either properties or delete keys ")
return errors.New("alter collection with empty properties and delete keys, expect to set either properties or delete keys ")
}
if len(a.Req.GetProperties()) > 0 && len(a.Req.GetDeleteKeys()) > 0 {
return errors.New("alter collection cannot provide properties and delete keys at the same time")
}
if hookutil.ContainsCipherProperties(a.Req.GetProperties(), a.Req.GetDeleteKeys()) {
log.Info("skip to alter collection due to cipher properties were detected in the properties")
return errors.New("can not alter cipher related properties")
}
oldColl, err := a.core.meta.GetCollectionByName(ctx, a.Req.GetDbName(), a.Req.GetCollectionName(), a.GetTs())
if err != nil {
log.Warn("get collection failed during changing collection state", zap.Error(err))
return err
}
var newProperties []*commonpb.KeyValuePair
if len(a.Req.Properties) > 0 {
if IsSubsetOfProperties(a.Req.GetProperties(), oldColl.Properties) {
log.Info("skip to alter collection due to no changes were detected in the properties")
return nil
}
newProperties = MergeProperties(oldColl.Properties, a.Req.GetProperties())
} else if len(a.Req.DeleteKeys) > 0 {
newProperties = DeleteProperties(oldColl.Properties, a.Req.GetDeleteKeys())
}
return executeAlterCollectionTaskSteps(ctx, a.core, oldColl, oldColl.Properties, newProperties, a.Req, a.GetTs())
}
func (a *alterCollectionTask) GetLockerKey() LockerKey {
collection := a.core.getCollectionIDStr(a.ctx, a.Req.GetDbName(), a.Req.GetCollectionName(), a.Req.GetCollectionID())
return NewLockerKeyChain(
NewClusterLockerKey(false),
NewDatabaseLockerKey(a.Req.GetDbName(), false),
NewCollectionLockerKey(collection, true),
)
}
func getCollectionDescription(props ...*commonpb.KeyValuePair) (bool, string, []*commonpb.KeyValuePair) {
hasDesc := false
desc := ""
@ -123,136 +47,26 @@ func getConsistencyLevel(props ...*commonpb.KeyValuePair) (bool, commonpb.Consis
for _, p := range props {
if p.GetKey() == common.ConsistencyLevel {
value := p.GetValue()
if level, err := strconv.ParseInt(value, 10, 32); err == nil {
if _, ok := commonpb.ConsistencyLevel_name[int32(level)]; ok {
return true, commonpb.ConsistencyLevel(level)
}
} else {
if level, ok := commonpb.ConsistencyLevel_value[value]; ok {
return true, commonpb.ConsistencyLevel(level)
}
if lv, ok := unmarshalConsistencyLevel(value); ok {
return true, lv
}
}
}
return false, commonpb.ConsistencyLevel(0)
}
func executeAlterCollectionTaskSteps(ctx context.Context,
core *Core,
col *model.Collection,
oldProperties []*commonpb.KeyValuePair,
newProperties []*commonpb.KeyValuePair,
request *milvuspb.AlterCollectionRequest,
ts Timestamp,
) error {
oldColl := col.Clone()
oldColl.Properties = oldProperties
newColl := col.Clone()
if ok, level := getConsistencyLevel(newProperties...); ok {
newColl.ConsistencyLevel = level
}
if ok, desc, props := getCollectionDescription(newProperties...); ok {
newColl.Description = desc
newColl.Properties = props
// unmarshalConsistencyLevel unmarshals the consistency level from the value.
func unmarshalConsistencyLevel(value string) (commonpb.ConsistencyLevel, bool) {
if level, err := strconv.ParseInt(value, 10, 32); err == nil {
if _, ok := commonpb.ConsistencyLevel_name[int32(level)]; ok {
return commonpb.ConsistencyLevel(level), true
}
} else {
newColl.Properties = newProperties
if level, ok := commonpb.ConsistencyLevel_value[value]; ok {
return commonpb.ConsistencyLevel(level), true
}
}
tso, err := core.tsoAllocator.GenerateTSO(1)
if err == nil {
newColl.UpdateTimestamp = tso
}
redoTask := newBaseRedoTask(core.stepExecutor)
redoTask.AddSyncStep(&AlterCollectionStep{
baseStep: baseStep{core: core},
oldColl: oldColl,
newColl: newColl,
ts: ts,
})
request.CollectionID = oldColl.CollectionID
redoTask.AddSyncStep(&BroadcastAlteredCollectionStep{
baseStep: baseStep{core: core},
req: request,
core: core,
})
// properties needs to be refreshed in the cache
aliases := core.meta.ListAliasesByID(ctx, oldColl.CollectionID)
redoTask.AddSyncStep(&expireCacheStep{
baseStep: baseStep{core: core},
dbName: request.GetDbName(),
collectionNames: append(aliases, request.GetCollectionName()),
collectionID: oldColl.CollectionID,
opts: []proxyutil.ExpireCacheOpt{proxyutil.SetMsgType(commonpb.MsgType_AlterCollection)},
})
oldReplicaNumber, _ := common.CollectionLevelReplicaNumber(oldColl.Properties)
oldResourceGroups, _ := common.CollectionLevelResourceGroups(oldColl.Properties)
newReplicaNumber, _ := common.CollectionLevelReplicaNumber(newColl.Properties)
newResourceGroups, _ := common.CollectionLevelResourceGroups(newColl.Properties)
left, right := lo.Difference(oldResourceGroups, newResourceGroups)
rgChanged := len(left) > 0 || len(right) > 0
replicaChanged := oldReplicaNumber != newReplicaNumber
if rgChanged || replicaChanged {
log.Ctx(ctx).Warn("alter collection trigger update load config",
zap.Int64("collectionID", oldColl.CollectionID),
zap.Int64("oldReplicaNumber", oldReplicaNumber),
zap.Int64("newReplicaNumber", newReplicaNumber),
zap.Strings("oldResourceGroups", oldResourceGroups),
zap.Strings("newResourceGroups", newResourceGroups),
)
redoTask.AddAsyncStep(NewSimpleStep("", func(ctx context.Context) ([]nestedStep, error) {
resp, err := core.mixCoord.UpdateLoadConfig(ctx, &querypb.UpdateLoadConfigRequest{
CollectionIDs: []int64{oldColl.CollectionID},
ReplicaNumber: int32(newReplicaNumber),
ResourceGroups: newResourceGroups,
})
if err := merr.CheckRPCCall(resp, err); err != nil {
log.Ctx(ctx).Warn("failed to trigger update load config for collection", zap.Int64("collectionID", newColl.CollectionID), zap.Error(err))
return nil, err
}
return nil, nil
}))
}
oldReplicateEnable, _ := common.IsReplicateEnabled(oldColl.Properties)
replicateEnable, ok := common.IsReplicateEnabled(newColl.Properties)
if ok && !replicateEnable && oldReplicateEnable {
replicateID, _ := common.GetReplicateID(oldColl.Properties)
redoTask.AddAsyncStep(NewSimpleStep("send replicate end msg for collection", func(ctx context.Context) ([]nestedStep, error) {
msgPack := &msgstream.MsgPack{}
msg := &msgstream.ReplicateMsg{
BaseMsg: msgstream.BaseMsg{
Ctx: ctx,
BeginTimestamp: ts,
EndTimestamp: ts,
HashValues: []uint32{0},
},
ReplicateMsg: &msgpb.ReplicateMsg{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Replicate,
Timestamp: ts,
ReplicateInfo: &commonpb.ReplicateInfo{
IsReplicate: true,
ReplicateID: replicateID,
},
},
IsEnd: true,
Database: newColl.DBName,
Collection: newColl.Name,
},
}
msgPack.Msgs = append(msgPack.Msgs, msg)
log.Info("send replicate end msg",
zap.String("collection", newColl.Name),
zap.String("database", newColl.DBName),
zap.String("replicateID", replicateID),
)
return nil, core.chanTimeTick.broadcastDmlChannels(newColl.PhysicalChannelNames, msgPack)
}))
}
return redoTask.Execute(ctx)
return commonpb.ConsistencyLevel_Strong, false
}
func DeleteProperties(oldProps []*commonpb.KeyValuePair, deleteKeys []string) []*commonpb.KeyValuePair {
@ -270,117 +84,6 @@ func DeleteProperties(oldProps []*commonpb.KeyValuePair, deleteKeys []string) []
return propKV
}
type alterCollectionFieldTask struct {
baseTask
Req *milvuspb.AlterCollectionFieldRequest
}
func (a *alterCollectionFieldTask) Prepare(ctx context.Context) error {
if a.Req.GetCollectionName() == "" {
return errors.New("alter collection field failed, collection name does not exists")
}
if a.Req.GetFieldName() == "" {
return errors.New("alter collection field failed, field name does not exists")
}
return nil
}
func (a *alterCollectionFieldTask) Execute(ctx context.Context) error {
if len(a.Req.GetProperties()) == 0 && len(a.Req.GetDeleteKeys()) == 0 {
return errors.New("The field properties to alter and keys to delete must not be empty at the same time")
}
oldColl, err := a.core.meta.GetCollectionByName(ctx, a.Req.GetDbName(), a.Req.GetCollectionName(), a.ts)
if err != nil {
log.Warn("get collection failed during changing collection state",
zap.String("collectionName", a.Req.GetCollectionName()),
zap.String("fieldName", a.Req.GetFieldName()),
zap.Uint64("ts", a.ts))
return err
}
oldFieldProperties, err := GetFieldProperties(oldColl, a.Req.GetFieldName())
if err != nil {
log.Warn("get field properties failed during changing collection state", zap.Error(err))
return err
}
ts := a.GetTs()
return executeAlterCollectionFieldTaskSteps(ctx, a.core, oldColl, oldFieldProperties, a.Req, ts)
}
func (a *alterCollectionFieldTask) GetLockerKey() LockerKey {
collection := a.core.getCollectionIDStr(a.ctx, a.Req.GetDbName(), a.Req.GetCollectionName(), 0)
return NewLockerKeyChain(
NewClusterLockerKey(false),
NewDatabaseLockerKey(a.Req.GetDbName(), false),
NewCollectionLockerKey(collection, true),
)
}
func executeAlterCollectionFieldTaskSteps(ctx context.Context,
core *Core,
col *model.Collection,
oldFieldProperties []*commonpb.KeyValuePair,
request *milvuspb.AlterCollectionFieldRequest,
ts Timestamp,
) error {
var err error
fieldName := request.GetFieldName()
var newFieldProperties []*commonpb.KeyValuePair
if len(request.Properties) > 0 {
newFieldProperties = UpdateFieldPropertyParams(oldFieldProperties, request.GetProperties())
} else if len(request.DeleteKeys) > 0 {
newFieldProperties = DeleteProperties(oldFieldProperties, request.GetDeleteKeys())
}
oldColl := col.Clone()
err = ResetFieldProperties(oldColl, fieldName, oldFieldProperties)
if err != nil {
return err
}
newColl := col.Clone()
err = ResetFieldProperties(newColl, fieldName, newFieldProperties)
if err != nil {
return err
}
tso, err := core.tsoAllocator.GenerateTSO(1)
if err == nil {
newColl.UpdateTimestamp = tso
}
redoTask := newBaseRedoTask(core.stepExecutor)
redoTask.AddSyncStep(&AlterCollectionStep{
baseStep: baseStep{core: core},
oldColl: oldColl,
newColl: newColl,
ts: ts,
fieldModify: true,
})
redoTask.AddSyncStep(&BroadcastAlteredCollectionStep{
baseStep: baseStep{core: core},
req: &milvuspb.AlterCollectionRequest{
Base: request.Base,
DbName: request.DbName,
CollectionName: request.CollectionName,
CollectionID: oldColl.CollectionID,
},
core: core,
})
redoTask.AddSyncStep(&expireCacheStep{
baseStep: baseStep{core: core},
dbName: request.GetDbName(),
collectionNames: []string{request.GetCollectionName()},
collectionID: oldColl.CollectionID,
opts: []proxyutil.ExpireCacheOpt{proxyutil.SetMsgType(commonpb.MsgType_AlterCollectionField)},
})
return redoTask.Execute(ctx)
}
func ResetFieldProperties(coll *model.Collection, fieldName string, newProps []*commonpb.KeyValuePair) error {
for i, field := range coll.Fields {
if field.Name == fieldName {

View File

@ -1,628 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/metastore/model"
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
func Test_alterCollectionTask_Prepare(t *testing.T) {
t.Run("invalid collectionID", func(t *testing.T) {
task := &alterCollectionTask{Req: &milvuspb.AlterCollectionRequest{Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollection}}}
err := task.Prepare(context.Background())
assert.Error(t, err)
})
t.Run("banned_delete_keys", func(t *testing.T) {
task := &alterCollectionTask{Req: &milvuspb.AlterCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollection},
CollectionName: "test_collection",
DeleteKeys: []string{common.EnableDynamicSchemaKey},
}}
err := task.Prepare(context.Background())
assert.Error(t, err)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
})
t.Run("normal case", func(t *testing.T) {
task := &alterCollectionTask{
Req: &milvuspb.AlterCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollection},
CollectionName: "cn",
},
}
err := task.Prepare(context.Background())
assert.NoError(t, err)
})
}
func Test_alterCollectionTask_Execute(t *testing.T) {
properties := []*commonpb.KeyValuePair{
{
Key: common.CollectionTTLConfigKey,
Value: "3600",
},
}
t.Run("properties is empty", func(t *testing.T) {
task := &alterCollectionTask{Req: &milvuspb.AlterCollectionRequest{Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollection}}}
err := task.Execute(context.Background())
assert.Error(t, err)
})
t.Run("failed to create alias", func(t *testing.T) {
core := newTestCore(withInvalidMeta())
task := &alterCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.AlterCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollection},
CollectionName: "cn",
Properties: properties,
},
}
err := task.Execute(context.Background())
assert.Error(t, err)
})
t.Run("alter step failed", func(t *testing.T) {
meta := mockrootcoord.NewIMetaTable(t)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(&model.Collection{CollectionID: int64(1)}, nil)
meta.On("AlterCollection",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(errors.New("err"))
meta.On("ListAliasesByID", mock.Anything, mock.Anything).Return([]string{})
core := newTestCore(withValidProxyManager(), withMeta(meta), withInvalidTsoAllocator())
task := &alterCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.AlterCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollection},
CollectionName: "cn",
Properties: properties,
},
}
err := task.Execute(context.Background())
assert.Error(t, err)
})
t.Run("broadcast step failed", func(t *testing.T) {
meta := mockrootcoord.NewIMetaTable(t)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(&model.Collection{CollectionID: int64(1)}, nil)
meta.On("AlterCollection",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(nil)
meta.On("ListAliasesByID", mock.Anything, mock.Anything).Return([]string{})
broker := newMockBroker()
broker.BroadcastAlteredCollectionFunc = func(ctx context.Context, req *milvuspb.AlterCollectionRequest) error {
return errors.New("err")
}
core := newTestCore(withValidProxyManager(), withMeta(meta), withBroker(broker), withInvalidTsoAllocator())
task := &alterCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.AlterCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollection},
CollectionName: "cn",
Properties: properties,
},
}
err := task.Execute(context.Background())
assert.Error(t, err)
})
t.Run("expire cache failed", func(t *testing.T) {
meta := mockrootcoord.NewIMetaTable(t)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(&model.Collection{CollectionID: int64(1)}, nil)
meta.On("AlterCollection",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(nil)
meta.On("ListAliasesByID", mock.Anything, mock.Anything).Return([]string{})
broker := newMockBroker()
broker.BroadcastAlteredCollectionFunc = func(ctx context.Context, req *milvuspb.AlterCollectionRequest) error {
return errors.New("err")
}
core := newTestCore(withInvalidProxyManager(), withMeta(meta), withBroker(broker), withInvalidTsoAllocator())
task := &alterCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.AlterCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollection},
CollectionName: "cn",
Properties: properties,
},
}
err := task.Execute(context.Background())
assert.Error(t, err)
})
t.Run("alter successfully", func(t *testing.T) {
meta := mockrootcoord.NewIMetaTable(t)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(&model.Collection{
CollectionID: int64(1),
Properties: []*commonpb.KeyValuePair{
{
Key: common.CollectionTTLConfigKey,
Value: "1",
},
{
Key: common.CollectionAutoCompactionKey,
Value: "true",
},
},
}, nil)
core := newTestCore(withValidProxyManager(), withMeta(meta), withInvalidTsoAllocator())
task := &alterCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.AlterCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollection},
CollectionName: "cn",
Properties: []*commonpb.KeyValuePair{
{
Key: common.CollectionAutoCompactionKey,
Value: "true",
},
},
},
}
err := task.Execute(context.Background())
assert.NoError(t, err)
})
t.Run("test update collection props", func(t *testing.T) {
coll := &model.Collection{
Properties: []*commonpb.KeyValuePair{
{
Key: common.CollectionTTLConfigKey,
Value: "1",
},
},
}
updateProps1 := []*commonpb.KeyValuePair{
{
Key: common.CollectionAutoCompactionKey,
Value: "true",
},
}
coll.Properties = MergeProperties(coll.Properties, updateProps1)
assert.Contains(t, coll.Properties, &commonpb.KeyValuePair{
Key: common.CollectionTTLConfigKey,
Value: "1",
})
assert.Contains(t, coll.Properties, &commonpb.KeyValuePair{
Key: common.CollectionAutoCompactionKey,
Value: "true",
})
updateProps2 := []*commonpb.KeyValuePair{
{
Key: common.CollectionTTLConfigKey,
Value: "2",
},
}
coll.Properties = MergeProperties(coll.Properties, updateProps2)
assert.Contains(t, coll.Properties, &commonpb.KeyValuePair{
Key: common.CollectionTTLConfigKey,
Value: "2",
})
assert.Contains(t, coll.Properties, &commonpb.KeyValuePair{
Key: common.CollectionAutoCompactionKey,
Value: "true",
})
updatePropsIso := []*commonpb.KeyValuePair{
{
Key: common.PartitionKeyIsolationKey,
Value: "true",
},
}
coll.Properties = MergeProperties(coll.Properties, updatePropsIso)
assert.Contains(t, coll.Properties, &commonpb.KeyValuePair{
Key: common.PartitionKeyIsolationKey,
Value: "true",
})
})
t.Run("test delete collection props", func(t *testing.T) {
coll := &model.Collection{
Properties: []*commonpb.KeyValuePair{
{
Key: common.CollectionTTLConfigKey,
Value: "1",
},
{
Key: common.CollectionAutoCompactionKey,
Value: "true",
},
},
}
deleteKeys := []string{common.CollectionTTLConfigKey}
coll.Properties = DeleteProperties(coll.Properties, deleteKeys)
assert.NotContains(t, coll.Properties, &commonpb.KeyValuePair{
Key: common.CollectionTTLConfigKey,
Value: "1",
})
assert.Contains(t, coll.Properties, &commonpb.KeyValuePair{
Key: common.CollectionAutoCompactionKey,
Value: "true",
})
deleteKeys = []string{"nonexistent.key"}
coll.Properties = DeleteProperties(coll.Properties, deleteKeys)
assert.Contains(t, coll.Properties, &commonpb.KeyValuePair{
Key: common.CollectionAutoCompactionKey,
Value: "true",
})
deleteKeys = []string{common.CollectionAutoCompactionKey}
coll.Properties = DeleteProperties(coll.Properties, deleteKeys)
assert.Empty(t, coll.Properties)
})
testFunc := func(t *testing.T, oldProps []*commonpb.KeyValuePair,
newProps []*commonpb.KeyValuePair, deleteKeys []string,
) chan *msgstream.ConsumeMsgPack {
meta := mockrootcoord.NewIMetaTable(t)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(&model.Collection{
CollectionID: int64(1),
Name: "cn",
DBName: "foo",
Properties: oldProps,
PhysicalChannelNames: []string{"by-dev-rootcoord-dml_1"},
}, nil)
meta.On("AlterCollection",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(nil)
meta.On("ListAliasesByID", mock.Anything, mock.Anything).Return([]string{})
broker := newMockBroker()
broker.BroadcastAlteredCollectionFunc = func(ctx context.Context, req *milvuspb.AlterCollectionRequest) error {
return nil
}
packChan := make(chan *msgstream.ConsumeMsgPack, 10)
ticker := newChanTimeTickSync(packChan)
ticker.addDmlChannels("by-dev-rootcoord-dml_1")
core := newTestCore(withValidProxyManager(), withMeta(meta), withBroker(broker), withTtSynchronizer(ticker), withInvalidTsoAllocator())
task := &alterCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.AlterCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollection},
CollectionName: "cn",
Properties: newProps,
DeleteKeys: deleteKeys,
},
}
err := task.Execute(context.Background())
assert.NoError(t, err)
return packChan
}
t.Run("alter successfully2", func(t *testing.T) {
oldProps := []*commonpb.KeyValuePair{
{
Key: common.ReplicateIDKey,
Value: "local-test",
},
}
newProps := append(properties, &commonpb.KeyValuePair{
Key: common.ReplicateEndTSKey,
Value: "10000",
})
packChan := testFunc(t, oldProps, newProps, nil)
unmarshalFactory := &msgstream.ProtoUDFactory{}
unmarshalDispatcher := unmarshalFactory.NewUnmarshalDispatcher()
time.Sleep(time.Second)
select {
case pack := <-packChan:
assert.Equal(t, commonpb.MsgType_Replicate, pack.Msgs[0].GetType())
tsMsg, err := pack.Msgs[0].Unmarshal(unmarshalDispatcher)
require.NoError(t, err)
replicateMsg := tsMsg.(*msgstream.ReplicateMsg)
assert.Equal(t, "foo", replicateMsg.ReplicateMsg.GetDatabase())
assert.Equal(t, "cn", replicateMsg.ReplicateMsg.GetCollection())
assert.True(t, replicateMsg.ReplicateMsg.GetIsEnd())
default:
assert.Fail(t, "no message sent")
}
})
t.Run("alter successfully3", func(t *testing.T) {
newProps := []*commonpb.KeyValuePair{
{
Key: common.ConsistencyLevel,
Value: "1",
},
}
testFunc(t, nil, newProps, nil)
})
t.Run("alter successfully4", func(t *testing.T) {
newProps := []*commonpb.KeyValuePair{
{
Key: common.CollectionDescription,
Value: "abc",
},
}
testFunc(t, nil, newProps, nil)
})
t.Run("alter successfully5", func(t *testing.T) {
testFunc(t, nil, nil, []string{common.CollectionDescription})
})
}
func Test_alterCollectionFieldTask_Prepare(t *testing.T) {
t.Run("invalid collection name", func(t *testing.T) {
task := &alterCollectionFieldTask{
Req: &milvuspb.AlterCollectionFieldRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollectionField},
},
}
err := task.Prepare(context.Background())
assert.Error(t, err)
})
t.Run("invalid field name", func(t *testing.T) {
task := &alterCollectionFieldTask{
Req: &milvuspb.AlterCollectionFieldRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollectionField},
CollectionName: "cn",
},
}
err := task.Prepare(context.Background())
assert.Error(t, err)
})
t.Run("normal name", func(t *testing.T) {
task := &alterCollectionFieldTask{
Req: &milvuspb.AlterCollectionFieldRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollectionField},
CollectionName: "cn",
FieldName: "ok",
},
}
err := task.Prepare(context.Background())
assert.NoError(t, err)
})
}
func Test_alterCollectionFieldTask_Execute(t *testing.T) {
testFn := func(req *milvuspb.AlterCollectionFieldRequest, meta *mockrootcoord.IMetaTable, expectError bool) {
broker := newMockBroker()
broker.BroadcastAlteredCollectionFunc = func(ctx context.Context, req *milvuspb.AlterCollectionRequest) error {
return nil
}
packChan := make(chan *msgstream.ConsumeMsgPack, 10)
ticker := newChanTimeTickSync(packChan)
ticker.addDmlChannels("by-dev-rootcoord-dml_1")
core := newTestCore(withValidProxyManager(), withMeta(meta), withBroker(broker), withTtSynchronizer(ticker), withInvalidTsoAllocator())
task := &alterCollectionFieldTask{
baseTask: newBaseTask(context.Background(), core),
Req: req,
}
err := task.Execute(context.Background())
if expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
}
t.Run("properties and deleteKeys are empty", func(t *testing.T) {
meta := mockrootcoord.NewIMetaTable(t)
req := &milvuspb.AlterCollectionFieldRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollectionField},
Properties: []*commonpb.KeyValuePair{},
DeleteKeys: []string{},
}
testFn(req, meta, true)
})
t.Run("collection not found", func(t *testing.T) {
meta := mockrootcoord.NewIMetaTable(t)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(nil, errors.New("collection not found"))
req := &milvuspb.AlterCollectionFieldRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollectionField},
CollectionName: "cn",
DeleteKeys: []string{common.MaxLengthKey},
}
testFn(req, meta, true)
})
t.Run("field not found", func(t *testing.T) {
meta := mockrootcoord.NewIMetaTable(t)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(&model.Collection{
CollectionID: int64(1),
Name: "cn",
DBName: "foo",
Fields: []*model.Field{},
}, nil)
req := &milvuspb.AlterCollectionFieldRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollectionField},
CollectionName: "cn",
DbName: "foo",
FieldName: "bar",
DeleteKeys: []string{common.MaxLengthKey},
}
testFn(req, meta, true)
})
t.Run("update properties", func(t *testing.T) {
meta := mockrootcoord.NewIMetaTable(t)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(&model.Collection{
CollectionID: int64(1),
Name: "cn",
DBName: "foo",
Fields: []*model.Field{{
FieldID: int64(1),
Name: "bar",
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{Key: common.MaxLengthKey, Value: "50"},
},
}},
}, nil)
meta.On("AlterCollection",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(nil)
req := &milvuspb.AlterCollectionFieldRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollectionField},
CollectionName: "cn",
DbName: "foo",
FieldName: "bar",
Properties: []*commonpb.KeyValuePair{
{Key: common.MaxLengthKey, Value: "100"},
},
}
testFn(req, meta, false)
})
t.Run("delete properties", func(t *testing.T) {
meta := mockrootcoord.NewIMetaTable(t)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(&model.Collection{
CollectionID: int64(1),
Name: "cn",
DBName: "foo",
Fields: []*model.Field{{
FieldID: int64(1),
Name: "bar",
DataType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{Key: common.MaxLengthKey, Value: "100"},
{Key: common.MmapEnabledKey, Value: "true"},
},
}},
}, nil)
meta.On("AlterCollection",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(nil)
req := &milvuspb.AlterCollectionFieldRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollectionField},
CollectionName: "cn",
DbName: "foo",
FieldName: "bar",
DeleteKeys: []string{common.MmapEnabledKey},
}
testFn(req, meta, false)
})
}

View File

@ -1,115 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
type alterDynamicFieldTask struct {
baseTask
Req *milvuspb.AlterCollectionRequest
oldColl *model.Collection
fieldSchema *schemapb.FieldSchema
targetValue bool
}
func (t *alterDynamicFieldTask) Prepare(ctx context.Context) error {
if err := CheckMsgType(t.Req.GetBase().GetMsgType(), commonpb.MsgType_AlterCollection); err != nil {
return err
}
oldColl, err := t.core.meta.GetCollectionByName(ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), t.ts)
if err != nil {
log.Ctx(ctx).Warn("get collection failed during alter dynamic schema",
zap.String("collectionName", t.Req.GetCollectionName()), zap.Uint64("ts", t.ts))
return err
}
t.oldColl = oldColl
if len(t.Req.GetProperties()) > 1 {
return merr.WrapErrParameterInvalidMsg("cannot alter dynamic schema with other properties")
}
// return nil for no-op
if oldColl.EnableDynamicField == t.targetValue {
return nil
}
// not support disabling since remove field not support yet.
if !t.targetValue {
return merr.WrapErrParameterInvalidMsg("dynamic schema cannot supported to be disabled")
}
// convert to add $meta json field, nullable, default value `{}`
t.fieldSchema = &schemapb.FieldSchema{
Name: common.MetaFieldName,
DataType: schemapb.DataType_JSON,
IsDynamic: true,
Nullable: true,
DefaultValue: &schemapb.ValueField{
Data: &schemapb.ValueField_BytesData{
BytesData: []byte("{}"),
},
},
}
if err := checkFieldSchema([]*schemapb.FieldSchema{t.fieldSchema}); err != nil {
return err
}
return nil
}
func (t *alterDynamicFieldTask) Execute(ctx context.Context) error {
// return nil for no-op
if t.oldColl.EnableDynamicField == t.targetValue {
log.Info("dynamic schema is same as target value",
zap.Bool("targetValue", t.targetValue),
zap.String("collectionName", t.Req.GetCollectionName()))
return nil
}
// assign field id
t.fieldSchema.FieldID = nextFieldID(t.oldColl)
// currently only add dynamic field support
// TODO check target value to remove field after supported
newField := model.UnmarshalFieldModel(t.fieldSchema)
t.Req.CollectionID = t.oldColl.CollectionID
ts := t.GetTs()
return executeAddCollectionFieldTaskSteps(ctx, t.core, t.oldColl, newField, t.Req, ts)
}
func (t *alterDynamicFieldTask) GetLockerKey() LockerKey {
collection := t.core.getCollectionIDStr(t.ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), 0)
return NewLockerKeyChain(
NewClusterLockerKey(false),
NewDatabaseLockerKey(t.Req.GetDbName(), false),
NewCollectionLockerKey(collection, true),
)
}

View File

@ -1,251 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"testing"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/mocks/distributed/mock_streaming"
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
type AlterDynamicSchemaTaskSuite struct {
suite.Suite
meta *mockrootcoord.IMetaTable
}
func (s *AlterDynamicSchemaTaskSuite) getDisabledCollection() *model.Collection {
return &model.Collection{
CollectionID: 1,
Name: "coll_disabled",
Fields: []*model.Field{
{
Name: "pk",
FieldID: 100,
IsPrimaryKey: true,
DataType: schemapb.DataType_Int64,
},
{
Name: "vec",
FieldID: 101,
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: common.DimKey,
Value: "768",
},
},
},
},
EnableDynamicField: false,
PhysicalChannelNames: []string{"dml_ch_01", "dml_ch_02"},
VirtualChannelNames: []string{"dml_ch_01", "dml_ch_02"},
}
}
func (s *AlterDynamicSchemaTaskSuite) SetupTest() {
s.meta = mockrootcoord.NewIMetaTable(s.T())
s.meta.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, "not_existed_coll", mock.Anything).Return(nil, merr.WrapErrCollectionNotFound("not_existed_coll")).Maybe()
s.meta.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, "coll_disabled", mock.Anything).Return(s.getDisabledCollection(), nil).Maybe()
s.meta.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, "coll_enabled", mock.Anything).Return(&model.Collection{
CollectionID: 1,
Name: "coll_enabled",
Fields: []*model.Field{
{
Name: "pk",
FieldID: 100,
IsPrimaryKey: true,
DataType: schemapb.DataType_Int64,
},
{
Name: "vec",
FieldID: 101,
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: common.DimKey,
Value: "768",
},
},
},
{
Name: "$meta",
IsDynamic: true,
DataType: schemapb.DataType_JSON,
},
},
EnableDynamicField: true,
}, nil).Maybe()
}
func (s *AlterDynamicSchemaTaskSuite) TestPrepare() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("invalid_msg_type", func() {
task := &alterDynamicFieldTask{Req: &milvuspb.AlterCollectionRequest{Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropCollection}}}
err := task.Prepare(ctx)
s.Error(err)
})
s.Run("alter_with_other_properties", func() {
task := &alterDynamicFieldTask{Req: &milvuspb.AlterCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropCollection},
Properties: []*commonpb.KeyValuePair{
{
Key: common.EnableDynamicSchemaKey,
Value: "true",
},
{
Key: "other_keys",
Value: "other_values",
},
},
}}
err := task.Prepare(ctx)
s.Error(err)
})
s.Run("disable_dynamic_field_for_disabled_coll", func() {
core := newTestCore(withMeta(s.meta))
task := &alterDynamicFieldTask{
baseTask: newBaseTask(ctx, core),
Req: &milvuspb.AlterCollectionRequest{Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollection}, CollectionName: "coll_disabled"},
targetValue: false,
}
err := task.Prepare(ctx)
s.NoError(err, "disabling dynamic schema on diabled collection shall be a no-op")
})
s.Run("disable_dynamic_field_for_enabled_coll", func() {
core := newTestCore(withMeta(s.meta))
task := &alterDynamicFieldTask{
baseTask: newBaseTask(ctx, core),
Req: &milvuspb.AlterCollectionRequest{Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollection}, CollectionName: "coll_enabled"},
targetValue: false,
}
err := task.Prepare(ctx)
s.Error(err, "disabling dynamic schema on enabled collection not supported yet")
})
s.Run("enable_dynamic_field_for_enabled_coll", func() {
core := newTestCore(withMeta(s.meta))
task := &alterDynamicFieldTask{
baseTask: newBaseTask(ctx, core),
Req: &milvuspb.AlterCollectionRequest{Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollection}, CollectionName: "coll_enabled"},
targetValue: true,
}
err := task.Prepare(ctx)
s.NoError(err)
})
s.Run("collection_not_exist", func() {
core := newTestCore(withMeta(s.meta))
task := &alterDynamicFieldTask{
baseTask: newBaseTask(ctx, core),
Req: &milvuspb.AlterCollectionRequest{Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollection}, CollectionName: "not_existed_coll"},
targetValue: true,
}
err := task.Prepare(ctx)
s.Error(err)
})
s.Run("normal_case", func() {
core := newTestCore(withMeta(s.meta))
task := &alterDynamicFieldTask{
baseTask: newBaseTask(ctx, core),
Req: &milvuspb.AlterCollectionRequest{Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollection}, CollectionName: "coll_disabled"},
targetValue: true,
}
err := task.Prepare(ctx)
s.NoError(err)
s.NotNil(task.fieldSchema)
})
}
func (s *AlterDynamicSchemaTaskSuite) TestExecute() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.Run("no_op", func() {
core := newTestCore(withMeta(s.meta))
task := &alterDynamicFieldTask{
baseTask: newBaseTask(ctx, core),
Req: &milvuspb.AlterCollectionRequest{Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollection}, CollectionName: "coll_disabled"},
oldColl: s.getDisabledCollection(),
targetValue: false,
}
err := task.Execute(ctx)
s.NoError(err)
})
s.Run("normal_case", func() {
b := mock_streaming.NewMockBroadcast(s.T())
wal := mock_streaming.NewMockWALAccesser(s.T())
wal.EXPECT().Broadcast().Return(b).Maybe()
streaming.SetWALForTest(wal)
b.EXPECT().Append(mock.Anything, mock.Anything).Return(&types.BroadcastAppendResult{
AppendResults: map[string]*types.AppendResult{
"dml_ch_01": {TimeTick: 100},
"dml_ch_02": {TimeTick: 101},
},
}, nil).Times(1)
s.meta.EXPECT().AlterCollection(
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(nil)
s.meta.EXPECT().ListAliasesByID(mock.Anything, mock.Anything).Return([]string{})
broker := newMockBroker()
broker.BroadcastAlteredCollectionFunc = func(ctx context.Context, req *milvuspb.AlterCollectionRequest) error {
return nil
}
alloc := newMockIDAllocator()
core := newTestCore(withValidProxyManager(), withMeta(s.meta), withBroker(broker), withIDAllocator(alloc))
task := &alterDynamicFieldTask{
baseTask: newBaseTask(ctx, core),
Req: &milvuspb.AlterCollectionRequest{Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_AlterCollection}, CollectionName: "coll_disabled"},
targetValue: true,
}
err := task.Prepare(ctx)
s.NoError(err)
err = task.Execute(ctx)
s.NoError(err)
})
}
func TestAlterDynamicSchemaTask(t *testing.T) {
suite.Run(t, new(AlterDynamicSchemaTaskSuite))
}

View File

@ -24,7 +24,6 @@ import (
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/pkg/v2/log"
@ -60,7 +59,7 @@ type Broker interface {
DropCollectionIndex(ctx context.Context, collID UniqueID, partIDs []UniqueID) error
// notify observer to clean their meta cache
BroadcastAlteredCollection(ctx context.Context, req *milvuspb.AlterCollectionRequest) error
BroadcastAlteredCollection(ctx context.Context, collectionID UniqueID) error
}
type ServerBroker struct {
@ -235,19 +234,16 @@ func (b *ServerBroker) GetSegmentIndexState(ctx context.Context, collID UniqueID
return resp.GetStates(), nil
}
func (b *ServerBroker) BroadcastAlteredCollection(ctx context.Context, req *milvuspb.AlterCollectionRequest) error {
func (b *ServerBroker) BroadcastAlteredCollection(ctx context.Context, collectionID UniqueID) error {
log.Ctx(ctx).Info("broadcasting request to alter collection",
zap.String("collectionName", req.GetCollectionName()),
zap.Int64("collectionID", req.GetCollectionID()),
zap.Any("props", req.GetProperties()),
zap.Any("deleteKeys", req.GetDeleteKeys()))
zap.Int64("collectionID", collectionID))
colMeta, err := b.s.meta.GetCollectionByID(ctx, req.GetDbName(), req.GetCollectionID(), typeutil.MaxTimestamp, false)
colMeta, err := b.s.meta.GetCollectionByID(ctx, "", collectionID, typeutil.MaxTimestamp, false)
if err != nil {
return err
}
db, err := b.s.meta.GetDatabaseByName(ctx, req.GetDbName(), typeutil.MaxTimestamp)
db, err := b.s.meta.GetDatabaseByName(ctx, colMeta.DBName, typeutil.MaxTimestamp)
if err != nil {
return err
}
@ -257,7 +253,7 @@ func (b *ServerBroker) BroadcastAlteredCollection(ctx context.Context, req *milv
partitionIDs = append(partitionIDs, p.PartitionID)
}
dcReq := &datapb.AlterCollectionRequest{
CollectionID: req.GetCollectionID(),
CollectionID: collectionID,
Schema: &schemapb.CollectionSchema{
Name: colMeta.Name,
Description: colMeta.Description,
@ -281,7 +277,7 @@ func (b *ServerBroker) BroadcastAlteredCollection(ctx context.Context, req *milv
if resp.ErrorCode != commonpb.ErrorCode_Success {
return errors.New(resp.Reason)
}
log.Ctx(ctx).Info("done to broadcast request to alter collection", zap.String("collectionName", req.GetCollectionName()), zap.Int64("collectionID", req.GetCollectionID()), zap.Any("props", req.GetProperties()), zap.Any("field", colMeta.Fields))
log.Ctx(ctx).Info("done to broadcast request to alter collection", zap.String("collectionName", colMeta.Name), zap.Int64("collectionID", collectionID), zap.Any("props", colMeta.Properties), zap.Any("field", colMeta.Fields))
return nil
}

View File

@ -25,7 +25,6 @@ import (
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/mocks"
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
@ -184,7 +183,7 @@ func TestServerBroker_BroadcastAlteredCollection(t *testing.T) {
c.meta = meta
b := newServerBroker(c)
ctx := context.Background()
err := b.BroadcastAlteredCollection(ctx, &milvuspb.AlterCollectionRequest{})
err := b.BroadcastAlteredCollection(ctx, 0)
assert.Error(t, err)
})
@ -202,7 +201,7 @@ func TestServerBroker_BroadcastAlteredCollection(t *testing.T) {
c.meta = meta
b := newServerBroker(c)
ctx := context.Background()
err := b.BroadcastAlteredCollection(ctx, &milvuspb.AlterCollectionRequest{})
err := b.BroadcastAlteredCollection(ctx, 0)
assert.Error(t, err)
})
@ -220,7 +219,7 @@ func TestServerBroker_BroadcastAlteredCollection(t *testing.T) {
c.meta = meta
b := newServerBroker(c)
ctx := context.Background()
err := b.BroadcastAlteredCollection(ctx, &milvuspb.AlterCollectionRequest{})
err := b.BroadcastAlteredCollection(ctx, 0)
assert.Error(t, err)
})
@ -238,11 +237,7 @@ func TestServerBroker_BroadcastAlteredCollection(t *testing.T) {
c.meta = meta
b := newServerBroker(c)
ctx := context.Background()
req := &milvuspb.AlterCollectionRequest{
CollectionID: 1,
}
err := b.BroadcastAlteredCollection(ctx, req)
err := b.BroadcastAlteredCollection(ctx, 1)
assert.NoError(t, err)
})
}

View File

@ -360,6 +360,7 @@ func (t *createCollectionTask) prepareSchema(ctx context.Context) error {
// Set properties for persistent
t.body.CollectionSchema.Properties = t.Req.GetProperties()
t.body.CollectionSchema.Version = 0
t.appendSysFields(t.body.CollectionSchema)
return nil
}

View File

@ -75,6 +75,7 @@ func (c *DDLCallback) registerAliasCallbacks() {
// registerCollectionCallbacks registers the collection callbacks.
func (c *DDLCallback) registerCollectionCallbacks() {
registry.RegisterCreateCollectionV1AckCallback(c.createCollectionV1AckCallback)
registry.RegisterAlterCollectionV2AckCallback(c.alterCollectionV2AckCallback)
registry.RegisterDropCollectionV1AckCallback(c.dropCollectionV1AckCallback)
}

View File

@ -0,0 +1,95 @@
package rootcoord
import (
"context"
"github.com/cockroachdb/errors"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// broadcastAlterCollectionForAddField broadcasts the put collection message for add field.
func (c *Core) broadcastAlterCollectionForAddField(ctx context.Context, req *milvuspb.AddCollectionFieldRequest) error {
broadcaster, err := startBroadcastWithCollectionLock(ctx, req.GetDbName(), req.GetCollectionName())
if err != nil {
return err
}
defer broadcaster.Close()
// check if the collection is created.
coll, err := c.meta.GetCollectionByName(ctx, req.GetDbName(), req.GetCollectionName(), typeutil.MaxTimestamp)
if err != nil {
return err
}
// check if the field schema is illegal.
fieldSchema := &schemapb.FieldSchema{}
if err = proto.Unmarshal(req.Schema, fieldSchema); err != nil {
return errors.Wrap(err, "failed to unmarshal field schema")
}
if err := checkFieldSchema([]*schemapb.FieldSchema{fieldSchema}); err != nil {
return errors.Wrap(err, "failed to check field schema")
}
// check if the field already exists
for _, field := range coll.Fields {
if field.Name == fieldSchema.Name {
// TODO: idempotency check here.
return merr.WrapErrParameterInvalidMsg("field already exists, name: %s", fieldSchema.Name)
}
}
// assign a new field id.
fieldSchema.FieldID = nextFieldID(coll)
// build new collection schema.
schema := &schemapb.CollectionSchema{
Name: coll.Name,
Description: coll.Description,
AutoID: coll.AutoID,
Fields: model.MarshalFieldModels(coll.Fields),
StructArrayFields: model.MarshalStructArrayFieldModels(coll.StructArrayFields),
Functions: model.MarshalFunctionModels(coll.Functions),
EnableDynamicField: coll.EnableDynamicField,
Properties: coll.Properties,
Version: coll.SchemaVersion + 1,
}
schema.Fields = append(schema.Fields, fieldSchema)
cacheExpirations, err := c.getCacheExpireForCollection(ctx, req.GetDbName(), req.GetCollectionName())
if err != nil {
return err
}
channels := make([]string, 0, len(coll.VirtualChannelNames)+1)
channels = append(channels, streaming.WAL().ControlChannel())
channels = append(channels, coll.VirtualChannelNames...)
// broadcast the put collection v2 message.
msg := message.NewAlterCollectionMessageBuilderV2().
WithHeader(&messagespb.AlterCollectionMessageHeader{
DbId: coll.DBID,
CollectionId: coll.CollectionID,
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{message.FieldMaskCollectionSchema},
},
CacheExpirations: cacheExpirations,
}).
WithBody(&messagespb.AlterCollectionMessageBody{
Updates: &messagespb.AlterCollectionMessageUpdates{
Schema: schema,
},
}).
WithBroadcast(channels).
MustBuildBroadcast()
if _, err := broadcaster.Broadcast(ctx, msg); err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,127 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func TestDDLCallbacksAlterCollectionAddField(t *testing.T) {
core := initStreamingSystemAndCore(t)
ctx := context.Background()
dbName := "testDB" + funcutil.RandomString(10)
collectionName := "testCollection" + funcutil.RandomString(10)
// database not found
resp, err := core.AddCollectionField(ctx, &milvuspb.AddCollectionFieldRequest{
DbName: dbName,
CollectionName: collectionName,
Schema: getFieldSchema("field2"),
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrDatabaseNotFound)
// collection not found
resp, err = core.CreateDatabase(ctx, &milvuspb.CreateDatabaseRequest{
DbName: util.DefaultDBName,
})
require.NoError(t, merr.CheckRPCCall(resp, err))
resp, err = core.AddCollectionField(ctx, &milvuspb.AddCollectionFieldRequest{
DbName: util.DefaultDBName,
CollectionName: collectionName,
Schema: getFieldSchema("field2"),
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrCollectionNotFound)
// atler collection field already exists
createCollectionForTest(t, ctx, core, dbName, collectionName)
resp, err = core.AddCollectionField(ctx, &milvuspb.AddCollectionFieldRequest{
DbName: dbName,
CollectionName: collectionName,
Schema: getFieldSchema("field1"),
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid)
// add illegal field schema
illegalFieldSchema := &schemapb.FieldSchema{
Name: "field2",
DataType: schemapb.DataType_String,
IsPrimaryKey: true,
Nullable: true,
}
illegalFieldSchemaBytes, _ := proto.Marshal(illegalFieldSchema)
resp, err = core.AddCollectionField(ctx, &milvuspb.AddCollectionFieldRequest{
DbName: dbName,
CollectionName: collectionName,
Schema: illegalFieldSchemaBytes,
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid)
// add new field successfully
resp, err = core.AddCollectionField(ctx, &milvuspb.AddCollectionFieldRequest{
DbName: dbName,
CollectionName: collectionName,
Schema: getFieldSchema("field2"),
})
require.NoError(t, merr.CheckRPCCall(resp, err))
assertFieldProperties(t, ctx, core, dbName, collectionName, "field1", "key1", "value1")
assertFieldExists(t, ctx, core, dbName, collectionName, "field2", 101)
assertSchemaVersion(t, ctx, core, dbName, collectionName, 1)
// add new field successfully
resp, err = core.AddCollectionField(ctx, &milvuspb.AddCollectionFieldRequest{
DbName: dbName,
CollectionName: collectionName,
Schema: getFieldSchema("field3"),
})
require.NoError(t, merr.CheckRPCCall(resp, err))
assertFieldProperties(t, ctx, core, dbName, collectionName, "field1", "key1", "value1")
assertFieldExists(t, ctx, core, dbName, collectionName, "field3", 102)
assertSchemaVersion(t, ctx, core, dbName, collectionName, 2)
}
func getFieldSchema(fieldName string) []byte {
fieldSchema := &schemapb.FieldSchema{
Name: fieldName,
DataType: schemapb.DataType_Int64,
}
schemaBytes, _ := proto.Marshal(fieldSchema)
return schemaBytes
}
func assertFieldExists(t *testing.T, ctx context.Context, core *Core, dbName string, collectionName string, fieldName string, fieldID int64) {
coll, err := core.meta.GetCollectionByName(ctx, dbName, collectionName, typeutil.MaxTimestamp)
require.NoError(t, err)
for _, field := range coll.Fields {
if field.Name == fieldName {
require.Equal(t, field.FieldID, fieldID)
return
}
}
require.Fail(t, "field not found")
}

View File

@ -0,0 +1,97 @@
package rootcoord
import (
"context"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func (c *Core) broadcastAlterCollectionV2ForAlterCollectionField(ctx context.Context, req *milvuspb.AlterCollectionFieldRequest) error {
broadcaster, err := startBroadcastWithCollectionLock(ctx, req.GetDbName(), req.GetCollectionName())
if err != nil {
return err
}
defer broadcaster.Close()
coll, err := c.meta.GetCollectionByName(ctx, req.GetDbName(), req.GetCollectionName(), typeutil.MaxTimestamp)
if err != nil {
return err
}
oldFieldProperties, err := GetFieldProperties(coll, req.GetFieldName())
if err != nil {
return err
}
oldFieldPropertiesMap := common.CloneKeyValuePairs(oldFieldProperties).ToMap()
for _, prop := range req.GetProperties() {
oldFieldPropertiesMap[prop.GetKey()] = prop.GetValue()
}
for _, deleteKey := range req.GetDeleteKeys() {
delete(oldFieldPropertiesMap, deleteKey)
}
newFieldProperties := common.NewKeyValuePairs(oldFieldPropertiesMap)
if newFieldProperties.Equal(oldFieldProperties) {
// if there's no change, return nil directly to promise idempotent.
return errIgnoredAlterCollection
}
// build new collection schema.
schema := &schemapb.CollectionSchema{
Name: coll.Name,
Description: coll.Description,
AutoID: coll.AutoID,
Fields: model.MarshalFieldModels(coll.Fields),
StructArrayFields: model.MarshalStructArrayFieldModels(coll.StructArrayFields),
Functions: model.MarshalFunctionModels(coll.Functions),
EnableDynamicField: coll.EnableDynamicField,
Properties: coll.Properties,
Version: coll.SchemaVersion + 1,
}
for _, field := range schema.Fields {
if field.Name == req.GetFieldName() {
field.TypeParams = newFieldProperties
break
}
}
cacheExpirations, err := c.getCacheExpireForCollection(ctx, req.GetDbName(), req.GetCollectionName())
if err != nil {
return err
}
header := &messagespb.AlterCollectionMessageHeader{
DbId: coll.DBID,
CollectionId: coll.CollectionID,
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{message.FieldMaskCollectionSchema},
},
CacheExpirations: cacheExpirations,
}
body := &messagespb.AlterCollectionMessageBody{
Updates: &messagespb.AlterCollectionMessageUpdates{
Schema: schema,
},
}
channels := make([]string, 0, len(coll.VirtualChannelNames)+1)
channels = append(channels, streaming.WAL().ControlChannel())
channels = append(channels, coll.VirtualChannelNames...)
msg := message.NewAlterCollectionMessageBuilderV2().
WithHeader(header).
WithBody(body).
WithBroadcast(channels).
MustBuildBroadcast()
if _, err := broadcaster.Broadcast(ctx, msg); err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,155 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func TestDDLCallbacksAlterCollectionField(t *testing.T) {
core := initStreamingSystemAndCore(t)
ctx := context.Background()
dbName := "testDB" + funcutil.RandomString(10)
collectionName := "testCollection" + funcutil.RandomString(10)
fieldName := "field1"
// database not found
resp, err := core.AlterCollectionField(ctx, &milvuspb.AlterCollectionFieldRequest{
DbName: dbName,
CollectionName: collectionName,
FieldName: fieldName,
Properties: []*commonpb.KeyValuePair{{Key: "key1", Value: "value1"}},
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrDatabaseNotFound)
// collection not found
resp, err = core.CreateDatabase(ctx, &milvuspb.CreateDatabaseRequest{
DbName: util.DefaultDBName,
})
require.NoError(t, merr.CheckRPCCall(resp, err))
resp, err = core.AlterCollectionField(ctx, &milvuspb.AlterCollectionFieldRequest{
DbName: util.DefaultDBName,
CollectionName: collectionName,
FieldName: fieldName,
Properties: []*commonpb.KeyValuePair{{Key: "key1", Value: "value1"}},
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrCollectionNotFound)
// atler collection field field not found
createCollectionForTest(t, ctx, core, dbName, collectionName)
resp, err = core.AlterCollectionField(ctx, &milvuspb.AlterCollectionFieldRequest{
DbName: dbName,
CollectionName: collectionName,
FieldName: fieldName + "2",
Properties: []*commonpb.KeyValuePair{{Key: "key1", Value: "value1"}},
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid)
resp, err = core.AlterCollectionField(ctx, &milvuspb.AlterCollectionFieldRequest{
DbName: dbName,
CollectionName: collectionName,
FieldName: fieldName,
Properties: []*commonpb.KeyValuePair{{Key: "key1", Value: "value1"}},
})
require.NoError(t, merr.CheckRPCCall(resp, err))
assertFieldProperties(t, ctx, core, dbName, collectionName, "field1", "key1", "value1")
assertSchemaVersion(t, ctx, core, dbName, collectionName, 1)
resp, err = core.AlterCollectionField(ctx, &milvuspb.AlterCollectionFieldRequest{
DbName: dbName,
CollectionName: collectionName,
FieldName: fieldName,
Properties: []*commonpb.KeyValuePair{
{Key: "key1", Value: "value1"},
{Key: "key2", Value: "value2"},
},
})
require.NoError(t, merr.CheckRPCCall(resp, err))
assertFieldProperties(t, ctx, core, dbName, collectionName, fieldName, "key2", "value2")
assertFieldProperties(t, ctx, core, dbName, collectionName, fieldName, "key1", "value1")
assertSchemaVersion(t, ctx, core, dbName, collectionName, 2)
// delete key and add new key
resp, err = core.AlterCollectionField(ctx, &milvuspb.AlterCollectionFieldRequest{
DbName: dbName,
CollectionName: collectionName,
FieldName: fieldName,
Properties: []*commonpb.KeyValuePair{
{Key: "key1", Value: "value1"},
{Key: "key3", Value: "value3"},
},
DeleteKeys: []string{"key2"},
})
require.NoError(t, merr.CheckRPCCall(resp, err))
assertFieldProperties(t, ctx, core, dbName, collectionName, fieldName, "key1", "value1")
assertFieldProperties(t, ctx, core, dbName, collectionName, fieldName, "key3", "value3")
assertFieldPropertiesNotFound(t, ctx, core, dbName, collectionName, fieldName, "key2")
assertSchemaVersion(t, ctx, core, dbName, collectionName, 3)
// idempotency check
resp, err = core.AlterCollectionField(ctx, &milvuspb.AlterCollectionFieldRequest{
DbName: dbName,
CollectionName: collectionName,
FieldName: fieldName,
Properties: []*commonpb.KeyValuePair{{Key: "key1", Value: "value1"}},
DeleteKeys: []string{"key2"},
})
require.NoError(t, merr.CheckRPCCall(resp, err))
assertFieldProperties(t, ctx, core, dbName, collectionName, fieldName, "key1", "value1")
assertFieldPropertiesNotFound(t, ctx, core, dbName, collectionName, fieldName, "key2")
assertSchemaVersion(t, ctx, core, dbName, collectionName, 3)
}
func assertFieldPropertiesNotFound(t *testing.T, ctx context.Context, core *Core, dbName string, collectionName string, fieldName string, key string) {
coll, err := core.meta.GetCollectionByName(ctx, dbName, collectionName, typeutil.MaxTimestamp)
require.NoError(t, err)
for _, field := range coll.Fields {
if field.Name == fieldName {
for _, property := range field.TypeParams {
if property.Key == key {
require.Fail(t, "property found", "property %s found in field %s", key, fieldName)
}
}
}
}
}
func assertFieldProperties(t *testing.T, ctx context.Context, core *Core, dbName string, collectionName string, fieldName string, key string, val string) {
coll, err := core.meta.GetCollectionByName(ctx, dbName, collectionName, typeutil.MaxTimestamp)
require.NoError(t, err)
for _, field := range coll.Fields {
if field.Name == fieldName {
for _, property := range field.TypeParams {
if property.Key == key {
require.Equal(t, val, property.Value)
return
}
}
}
}
}

View File

@ -0,0 +1,129 @@
package rootcoord
import (
"context"
"fmt"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/broadcast"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func (c *Core) broadcastAlterCollectionForRenameCollection(ctx context.Context, req *milvuspb.RenameCollectionRequest) error {
if req.DbName == "" {
req.DbName = util.DefaultDBName
}
if req.NewDBName == "" {
req.NewDBName = req.DbName
}
if req.NewName == "" {
return merr.WrapErrParameterInvalidMsg("new collection name should not be empty")
}
if req.OldName == "" {
return merr.WrapErrParameterInvalidMsg("old collection name should not be empty")
}
if req.DbName == req.NewDBName && req.OldName == req.NewName {
// no-op here.
return merr.WrapErrParameterInvalidMsg("collection name or database name should be different")
}
// StartBroadcastWithResourceKeys will deduplicate the resource keys itself, so it's safe to add all the resource keys here.
rks := []message.ResourceKey{
message.NewSharedDBNameResourceKey(req.GetNewDBName()),
message.NewSharedDBNameResourceKey(req.GetDbName()),
message.NewExclusiveCollectionNameResourceKey(req.GetDbName(), req.GetOldName()),
message.NewExclusiveCollectionNameResourceKey(req.GetNewDBName(), req.GetNewName()),
}
broadcaster, err := broadcast.StartBroadcastWithResourceKeys(ctx, rks...)
if err != nil {
return err
}
defer broadcaster.Close()
if err := c.validateEncryption(ctx, req.GetDbName(), req.GetNewDBName()); err != nil {
return err
}
if err := c.meta.CheckIfCollectionRenamable(ctx, req.GetDbName(), req.GetOldName(), req.GetNewDBName(), req.GetNewName()); err != nil {
return err
}
newDB, err := c.meta.GetDatabaseByName(ctx, req.GetNewDBName(), typeutil.MaxTimestamp)
if err != nil {
return err
}
coll, err := c.meta.GetCollectionByName(ctx, req.GetDbName(), req.GetOldName(), typeutil.MaxTimestamp)
if err != nil {
return err
}
updateMask := &fieldmaskpb.FieldMask{
Paths: []string{},
}
updates := &message.AlterCollectionMessageUpdates{}
if req.GetNewDBName() != req.GetDbName() {
updates.DbName = newDB.Name
updates.DbId = newDB.ID
updateMask.Paths = append(updateMask.Paths, message.FieldMaskDB)
}
if req.GetNewName() != req.GetOldName() {
updates.CollectionName = req.GetNewName()
updateMask.Paths = append(updateMask.Paths, message.FieldMaskCollectionName)
}
channels := make([]string, 0, len(coll.VirtualChannelNames)+1)
channels = append(channels, streaming.WAL().ControlChannel())
channels = append(channels, coll.VirtualChannelNames...)
cacheExpirations, err := c.getCacheExpireForCollection(ctx, req.GetDbName(), req.GetOldName())
if err != nil {
return err
}
msg := message.NewAlterCollectionMessageBuilderV2().
WithHeader(&message.AlterCollectionMessageHeader{
DbId: coll.DBID,
CollectionId: coll.CollectionID,
UpdateMask: updateMask,
CacheExpirations: cacheExpirations,
}).
WithBody(&message.AlterCollectionMessageBody{
Updates: updates,
}).
WithBroadcast(channels).
MustBuildBroadcast()
_, err = broadcaster.Broadcast(ctx, msg)
return err
}
func (c *Core) validateEncryption(ctx context.Context, oldDBName string, newDBName string) error {
if oldDBName == newDBName {
return nil
}
// Check if renaming across databases with encryption enabled
// old and new DB names are filled in Prepare, shouldn't be empty here
originalDB, err := c.meta.GetDatabaseByName(ctx, oldDBName, typeutil.MaxTimestamp)
if err != nil {
return fmt.Errorf("failed to get original database: %w", err)
}
targetDB, err := c.meta.GetDatabaseByName(ctx, newDBName, typeutil.MaxTimestamp)
if err != nil {
return fmt.Errorf("target database %s not found: %w", newDBName, err)
}
// Check if either database has encryption enabled
if hookutil.IsDBEncryptionEnabled(originalDB.Properties) || hookutil.IsDBEncryptionEnabled(targetDB.Properties) {
return fmt.Errorf("deny to change collection databases due to at least one database enabled encryption, original DB: %s, target DB: %s", oldDBName, newDBName)
}
return nil
}

View File

@ -0,0 +1,159 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func TestDDLCallbacksAlterCollectionName(t *testing.T) {
core := initStreamingSystemAndCore(t)
ctx := context.Background()
dbName := "testDB" + funcutil.RandomString(10)
collectionName := "testCollection" + funcutil.RandomString(10)
newDbName := "testDatabaseNew" + funcutil.RandomString(10)
newCollectionName := "testCollectionNew" + funcutil.RandomString(10)
resp, err := core.RenameCollection(ctx, &milvuspb.RenameCollectionRequest{})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid)
resp, err = core.RenameCollection(ctx, &milvuspb.RenameCollectionRequest{
OldName: collectionName,
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid)
resp, err = core.RenameCollection(ctx, &milvuspb.RenameCollectionRequest{
DbName: dbName,
OldName: collectionName,
NewDBName: dbName,
NewName: collectionName,
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid)
resp, err = core.RenameCollection(ctx, &milvuspb.RenameCollectionRequest{
DbName: dbName,
OldName: collectionName,
NewDBName: newDbName,
NewName: newCollectionName,
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrDatabaseNotFound)
resp, err = core.CreateDatabase(ctx, &milvuspb.CreateDatabaseRequest{
DbName: util.DefaultDBName,
})
require.NoError(t, merr.CheckRPCCall(resp, err))
resp, err = core.RenameCollection(ctx, &milvuspb.RenameCollectionRequest{
DbName: util.DefaultDBName,
OldName: collectionName,
NewName: newCollectionName,
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrCollectionNotFound)
// rename a collection success in one database.
createCollectionForTest(t, ctx, core, dbName, collectionName)
resp, err = core.RenameCollection(ctx, &milvuspb.RenameCollectionRequest{
DbName: dbName,
OldName: collectionName,
NewName: newCollectionName,
})
require.NoError(t, merr.CheckRPCCall(resp, err))
coll, err := core.meta.GetCollectionByName(ctx, dbName, collectionName, typeutil.MaxTimestamp)
require.ErrorIs(t, err, merr.ErrCollectionNotFound)
require.Nil(t, coll)
coll, err = core.meta.GetCollectionByName(ctx, dbName, newCollectionName, typeutil.MaxTimestamp)
require.NoError(t, err)
require.NotNil(t, coll)
require.Equal(t, coll.Name, newCollectionName)
// rename a collection success cross database.
resp, err = core.CreateDatabase(ctx, &milvuspb.CreateDatabaseRequest{
DbName: newDbName,
})
require.NoError(t, merr.CheckRPCCall(resp, err))
resp, err = core.RenameCollection(ctx, &milvuspb.RenameCollectionRequest{
DbName: dbName,
OldName: newCollectionName,
NewDBName: newDbName,
NewName: newCollectionName,
})
require.NoError(t, merr.CheckRPCCall(resp, err))
coll, err = core.meta.GetCollectionByName(ctx, dbName, newCollectionName, typeutil.MaxTimestamp)
require.ErrorIs(t, err, merr.ErrCollectionNotFound)
require.Nil(t, coll)
coll, err = core.meta.GetCollectionByName(ctx, newDbName, newCollectionName, typeutil.MaxTimestamp)
require.NoError(t, err)
require.NotNil(t, coll)
require.Equal(t, coll.Name, newCollectionName)
require.Equal(t, coll.DBName, newDbName)
// rename a collection and database at same time.
resp, err = core.RenameCollection(ctx, &milvuspb.RenameCollectionRequest{
DbName: newDbName,
OldName: newCollectionName,
NewDBName: dbName,
NewName: collectionName,
})
require.NoError(t, merr.CheckRPCCall(resp, err))
coll, err = core.meta.GetCollectionByName(ctx, newDbName, newCollectionName, typeutil.MaxTimestamp)
require.ErrorIs(t, err, merr.ErrCollectionNotFound)
require.Nil(t, coll)
coll, err = core.meta.GetCollectionByName(ctx, newDbName, collectionName, typeutil.MaxTimestamp)
require.ErrorIs(t, err, merr.ErrCollectionNotFound)
require.Nil(t, coll)
coll, err = core.meta.GetCollectionByName(ctx, dbName, newCollectionName, typeutil.MaxTimestamp)
require.ErrorIs(t, err, merr.ErrCollectionNotFound)
require.Nil(t, coll)
coll, err = core.meta.GetCollectionByName(ctx, dbName, collectionName, typeutil.MaxTimestamp)
require.NoError(t, err)
require.NotNil(t, coll)
require.Equal(t, coll.Name, collectionName)
require.Equal(t, coll.DBName, dbName)
require.Equal(t, coll.SchemaVersion, int32(0)) // SchemaVersion should not be changed with rename.
resp, err = core.CreateAlias(ctx, &milvuspb.CreateAliasRequest{
DbName: dbName,
CollectionName: collectionName,
Alias: newCollectionName,
})
require.NoError(t, merr.CheckRPCCall(resp, err))
// rename a collection has aliases cross database should fail.
resp, err = core.RenameCollection(ctx, &milvuspb.RenameCollectionRequest{
DbName: dbName,
OldName: collectionName,
NewDBName: newDbName,
NewName: collectionName,
})
require.Error(t, merr.CheckRPCCall(resp, err))
// rename a collection to a duplicated name should fail.
resp, err = core.RenameCollection(ctx, &milvuspb.RenameCollectionRequest{
DbName: dbName,
OldName: collectionName,
NewName: newCollectionName,
})
require.Error(t, merr.CheckRPCCall(resp, err))
}

View File

@ -0,0 +1,294 @@
package rootcoord
import (
"context"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message/ce"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// broadcastAlterCollectionForAlterCollection broadcasts the put collection message for alter collection.
func (c *Core) broadcastAlterCollectionForAlterCollection(ctx context.Context, req *milvuspb.AlterCollectionRequest) error {
if req.GetCollectionName() == "" {
return merr.WrapErrParameterInvalidMsg("alter collection failed, collection name does not exists")
}
if len(req.GetProperties()) == 0 && len(req.GetDeleteKeys()) == 0 {
return merr.WrapErrParameterInvalidMsg("no properties or delete keys provided")
}
if len(req.GetProperties()) > 0 && len(req.GetDeleteKeys()) > 0 {
return merr.WrapErrParameterInvalidMsg("can not provide properties and deletekeys at the same time")
}
if hookutil.ContainsCipherProperties(req.GetProperties(), req.GetDeleteKeys()) {
return merr.WrapErrParameterInvalidMsg("can not alter cipher related properties")
}
if funcutil.SliceContain(req.GetDeleteKeys(), common.EnableDynamicSchemaKey) {
return merr.WrapErrParameterInvalidMsg("cannot delete key %s, dynamic field schema could support set to true/false", common.EnableDynamicSchemaKey)
}
isEnableDynamicSchema, targetValue, err := common.IsEnableDynamicSchema(req.GetProperties())
if err != nil {
return merr.WrapErrParameterInvalidMsg("invalid dynamic schema property value: %s", req.GetProperties()[0].GetValue())
}
if isEnableDynamicSchema {
// if there's dynamic schema property, it will add a new dynamic field into the collection.
// the property cannot be seen at collection properties, only add a new field into the collection.
return c.broadcastAlterCollectionForAlterDynamicField(ctx, req, targetValue)
}
broadcaster, err := startBroadcastWithCollectionLock(ctx, req.GetDbName(), req.GetCollectionName())
if err != nil {
return err
}
defer broadcaster.Close()
// check if the collection exists
coll, err := c.meta.GetCollectionByName(ctx, req.GetDbName(), req.GetCollectionName(), typeutil.MaxTimestamp)
if err != nil {
return err
}
cacheExpirations, err := c.getCacheExpireForCollection(ctx, req.GetDbName(), req.GetCollectionName())
if err != nil {
return err
}
header := &messagespb.AlterCollectionMessageHeader{
DbId: coll.DBID,
CollectionId: coll.CollectionID,
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{},
},
CacheExpirations: cacheExpirations,
}
udpates := &messagespb.AlterCollectionMessageUpdates{}
// Apply the properties to override the existing properties.
newProperties := common.CloneKeyValuePairs(coll.Properties).ToMap()
for _, prop := range req.GetProperties() {
switch prop.GetKey() {
case common.CollectionDescription:
if prop.GetValue() != coll.Description {
udpates.Description = prop.GetValue()
header.UpdateMask.Paths = append(header.UpdateMask.Paths, message.FieldMaskCollectionDescription)
}
case common.ConsistencyLevel:
if lv, ok := unmarshalConsistencyLevel(prop.GetValue()); ok && lv != coll.ConsistencyLevel {
udpates.ConsistencyLevel = lv
header.UpdateMask.Paths = append(header.UpdateMask.Paths, message.FieldMaskCollectionConsistencyLevel)
}
default:
newProperties[prop.GetKey()] = prop.GetValue()
}
}
for _, deleteKey := range req.GetDeleteKeys() {
delete(newProperties, deleteKey)
}
// Check if the properties are changed.
newPropsKeyValuePairs := common.NewKeyValuePairs(newProperties)
if !newPropsKeyValuePairs.Equal(coll.Properties) {
udpates.Properties = newPropsKeyValuePairs
header.UpdateMask.Paths = append(header.UpdateMask.Paths, message.FieldMaskCollectionProperties)
}
// if there's no change, return nil directly to promise idempotent.
if len(header.UpdateMask.Paths) == 0 {
return errIgnoredAlterCollection
}
// fill the put load config if rg or replica number is changed.
udpates.AlterLoadConfig = c.getAlterLoadConfigOfAlterCollection(coll.Properties, udpates.Properties)
channels := make([]string, 0, len(coll.VirtualChannelNames)+1)
channels = append(channels, streaming.WAL().ControlChannel())
channels = append(channels, coll.VirtualChannelNames...)
msg := message.NewAlterCollectionMessageBuilderV2().
WithHeader(header).
WithBody(&messagespb.AlterCollectionMessageBody{
Updates: udpates,
}).
WithBroadcast(channels).
MustBuildBroadcast()
if _, err := broadcaster.Broadcast(ctx, msg); err != nil {
return err
}
return nil
}
// broadcastAlterCollectionForAlterDynamicField broadcasts the put collection message for alter dynamic field.
func (c *Core) broadcastAlterCollectionForAlterDynamicField(ctx context.Context, req *milvuspb.AlterCollectionRequest, targetValue bool) error {
if len(req.GetProperties()) != 1 {
return merr.WrapErrParameterInvalidMsg("cannot alter dynamic schema with other properties at the same time")
}
broadcaster, err := startBroadcastWithCollectionLock(ctx, req.GetDbName(), req.GetCollectionName())
if err != nil {
return err
}
defer broadcaster.Close()
coll, err := c.meta.GetCollectionByName(ctx, req.GetDbName(), req.GetCollectionName(), typeutil.MaxTimestamp)
if err != nil {
return err
}
// return nil for no-op
if coll.EnableDynamicField == targetValue {
return errIgnoredAlterCollection
}
// not support disabling since remove field not support yet.
if !targetValue {
return merr.WrapErrParameterInvalidMsg("dynamic schema cannot supported to be disabled")
}
// convert to add $meta json field, nullable, default value `{}`
fieldSchema := &schemapb.FieldSchema{
Name: common.MetaFieldName,
DataType: schemapb.DataType_JSON,
IsDynamic: true,
Nullable: true,
DefaultValue: &schemapb.ValueField{
Data: &schemapb.ValueField_BytesData{
BytesData: []byte("{}"),
},
},
}
if err := checkFieldSchema([]*schemapb.FieldSchema{fieldSchema}); err != nil {
return err
}
fieldSchema.FieldID = nextFieldID(coll)
schema := &schemapb.CollectionSchema{
Name: coll.Name,
Description: coll.Description,
AutoID: coll.AutoID,
Fields: model.MarshalFieldModels(coll.Fields),
StructArrayFields: model.MarshalStructArrayFieldModels(coll.StructArrayFields),
Functions: model.MarshalFunctionModels(coll.Functions),
EnableDynamicField: targetValue,
Properties: coll.Properties,
Version: coll.SchemaVersion + 1,
}
schema.Fields = append(schema.Fields, fieldSchema)
channels := make([]string, 0, len(coll.VirtualChannelNames)+1)
channels = append(channels, streaming.WAL().ControlChannel())
channels = append(channels, coll.VirtualChannelNames...)
cacheExpirations, err := c.getCacheExpireForCollection(ctx, req.GetDbName(), req.GetCollectionName())
if err != nil {
return err
}
// broadcast the put collection v2 message.
msg := message.NewAlterCollectionMessageBuilderV2().
WithHeader(&messagespb.AlterCollectionMessageHeader{
DbId: coll.DBID,
CollectionId: coll.CollectionID,
UpdateMask: &fieldmaskpb.FieldMask{
Paths: []string{message.FieldMaskCollectionSchema},
},
CacheExpirations: cacheExpirations,
}).
WithBody(&messagespb.AlterCollectionMessageBody{
Updates: &messagespb.AlterCollectionMessageUpdates{
Schema: schema,
},
}).
WithBroadcast(channels).
MustBuildBroadcast()
if _, err := broadcaster.Broadcast(ctx, msg); err != nil {
return err
}
return nil
}
// getCacheExpireForCollection gets the cache expirations for collection.
func (c *Core) getCacheExpireForCollection(ctx context.Context, dbName string, collectionName string) (*message.CacheExpirations, error) {
coll, err := c.meta.GetCollectionByName(ctx, dbName, collectionName, typeutil.MaxTimestamp)
if err != nil {
return nil, err
}
aliases, err := c.meta.ListAliases(ctx, dbName, collectionName, typeutil.MaxTimestamp)
if err != nil {
return nil, err
}
builder := ce.NewBuilder()
builder.WithLegacyProxyCollectionMetaCache(
ce.OptLPCMDBName(dbName),
ce.OptLPCMCollectionName(collectionName),
ce.OptLPCMCollectionID(coll.CollectionID),
ce.OptLPCMMsgType(commonpb.MsgType_AlterCollection),
)
for _, alias := range aliases {
builder.WithLegacyProxyCollectionMetaCache(
ce.OptLPCMDBName(dbName),
ce.OptLPCMCollectionName(alias),
ce.OptLPCMCollectionID(coll.CollectionID),
ce.OptLPCMMsgType(commonpb.MsgType_AlterAlias),
)
}
return builder.Build(), nil
}
// getAlterLoadConfigOfAlterCollection gets the alter load config of alter collection.
func (c *Core) getAlterLoadConfigOfAlterCollection(oldProps []*commonpb.KeyValuePair, newProps []*commonpb.KeyValuePair) *message.AlterLoadConfigOfAlterCollection {
oldReplicaNumber, _ := common.CollectionLevelReplicaNumber(oldProps)
oldResourceGroups, _ := common.CollectionLevelResourceGroups(oldProps)
newReplicaNumber, _ := common.CollectionLevelReplicaNumber(newProps)
newResourceGroups, _ := common.CollectionLevelResourceGroups(newProps)
left, right := lo.Difference(oldResourceGroups, newResourceGroups)
rgChanged := len(left) > 0 || len(right) > 0
replicaChanged := oldReplicaNumber != newReplicaNumber
if !replicaChanged && !rgChanged {
return nil
}
return &message.AlterLoadConfigOfAlterCollection{
ReplicaNumber: int32(newReplicaNumber),
ResourceGroups: newResourceGroups,
}
}
func (c *DDLCallback) alterCollectionV2AckCallback(ctx context.Context, result message.BroadcastResultAlterCollectionMessageV2) error {
header := result.Message.Header()
body := result.Message.MustBody()
if err := c.meta.AlterCollection(ctx, result); err != nil {
if errors.Is(err, errAlterCollectionNotFound) {
log.Ctx(ctx).Warn("alter a non-existent collection, ignore it", log.FieldMessage(result.Message))
return nil
}
return errors.Wrap(err, "failed to alter collection")
}
if body.Updates.AlterLoadConfig != nil {
resp, err := c.mixCoord.UpdateLoadConfig(ctx, &querypb.UpdateLoadConfigRequest{
CollectionIDs: []int64{header.CollectionId},
ReplicaNumber: body.Updates.AlterLoadConfig.ReplicaNumber,
ResourceGroups: body.Updates.AlterLoadConfig.ResourceGroups,
})
if err := merr.CheckRPCCall(resp, err); err != nil {
return errors.Wrap(err, "failed to update load config")
}
}
if err := c.broker.BroadcastAlteredCollection(ctx, header.CollectionId); err != nil {
return errors.Wrap(err, "failed to broadcast altered collection")
}
return c.ExpireCaches(ctx, header, result.GetControlChannelResult().TimeTick)
}

View File

@ -0,0 +1,302 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func TestDDLCallbacksAlterCollectionProperties(t *testing.T) {
core := initStreamingSystemAndCore(t)
ctx := context.Background()
dbName := "testDB" + funcutil.RandomString(10)
collectionName := "testCollection" + funcutil.RandomString(10)
// Cannot alter collection with empty properties and delete keys.
resp, err := core.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid)
// Cannot alter collection properties with delete keys at same time.
resp, err = core.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Properties: []*commonpb.KeyValuePair{{Key: common.CollectionReplicaNumber, Value: "1"}},
DeleteKeys: []string{common.CollectionReplicaNumber},
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid)
// hook related properties are not allowed to be altered.
resp, err = core.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Properties: []*commonpb.KeyValuePair{{Key: hookutil.EncryptionEnabledKey, Value: "1"}},
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid)
// Alter a database that does not exist should return error.
resp, err = core.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
DeleteKeys: []string{common.CollectionReplicaNumber},
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrDatabaseNotFound)
// Alter a collection that does not exist should return error.
resp, err = core.CreateDatabase(ctx, &milvuspb.CreateDatabaseRequest{
DbName: util.DefaultDBName,
})
require.NoError(t, merr.CheckRPCCall(resp, err))
resp, err = core.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
DbName: util.DefaultDBName,
CollectionName: collectionName,
DeleteKeys: []string{common.CollectionReplicaNumber},
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrCollectionNotFound)
// atler a property of a collection.
createCollectionAndAliasForTest(t, ctx, core, dbName, collectionName)
assertReplicaNumber(t, ctx, core, dbName, collectionName, 1)
resp, err = core.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Properties: []*commonpb.KeyValuePair{
{Key: common.CollectionReplicaNumber, Value: "2"},
{Key: common.CollectionResourceGroups, Value: "rg1,rg2"},
},
})
require.NoError(t, merr.CheckRPCCall(resp, err))
assertReplicaNumber(t, ctx, core, dbName, collectionName, 2)
assertResourceGroups(t, ctx, core, dbName, collectionName, []string{"rg1", "rg2"})
// delete a property of a collection.
resp, err = core.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
DeleteKeys: []string{common.CollectionReplicaNumber, common.CollectionResourceGroups},
})
require.NoError(t, merr.CheckRPCCall(resp, err))
assertReplicaNumber(t, ctx, core, dbName, collectionName, 0)
assertResourceGroups(t, ctx, core, dbName, collectionName, []string{})
// alter consistency level and description of a collection.
resp, err = core.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Properties: []*commonpb.KeyValuePair{
{Key: common.ConsistencyLevel, Value: commonpb.ConsistencyLevel_Eventually.String()},
{Key: common.CollectionDescription, Value: "description2"},
},
})
require.NoError(t, merr.CheckRPCCall(resp, err))
assertConsistencyLevel(t, ctx, core, dbName, collectionName, commonpb.ConsistencyLevel_Eventually)
assertDescription(t, ctx, core, dbName, collectionName, "description2")
// alter collection should be idempotent.
resp, err = core.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Properties: []*commonpb.KeyValuePair{
{Key: common.ConsistencyLevel, Value: commonpb.ConsistencyLevel_Eventually.String()},
{Key: common.CollectionDescription, Value: "description2"},
},
})
require.NoError(t, merr.CheckRPCCall(resp, err))
assertConsistencyLevel(t, ctx, core, dbName, collectionName, commonpb.ConsistencyLevel_Eventually)
assertDescription(t, ctx, core, dbName, collectionName, "description2")
assertSchemaVersion(t, ctx, core, dbName, collectionName, 0) // schema version should not be changed with alter collection properties.
// update dynamic schema property with other properties should return error.
resp, err = core.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Properties: []*commonpb.KeyValuePair{{Key: common.EnableDynamicSchemaKey, Value: "true"}, {Key: common.CollectionReplicaNumber, Value: "1"}},
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid)
}
func TestDDLCallbacksAlterCollectionPropertiesForDynamicField(t *testing.T) {
core := initStreamingSystemAndCore(t)
ctx := context.Background()
dbName := "testDB" + funcutil.RandomString(10)
collectionName := "testCollection" + funcutil.RandomString(10)
createCollectionAndAliasForTest(t, ctx, core, dbName, collectionName)
// update dynamic schema property with other properties should return error.
resp, err := core.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Properties: []*commonpb.KeyValuePair{{Key: common.EnableDynamicSchemaKey, Value: "true"}, {Key: common.CollectionReplicaNumber, Value: "1"}},
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid)
// update dynamic schema property with invalid value should return error.
resp, err = core.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Properties: []*commonpb.KeyValuePair{{Key: common.EnableDynamicSchemaKey, Value: "123123"}},
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid)
// update dynamic schema property with other properties should return error.
resp, err = core.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Properties: []*commonpb.KeyValuePair{{Key: common.EnableDynamicSchemaKey, Value: "true"}},
})
require.NoError(t, merr.CheckRPCCall(resp, err))
assertDynamicSchema(t, ctx, core, dbName, collectionName, true)
assertSchemaVersion(t, ctx, core, dbName, collectionName, 1) // add dynamic field should increment schema version.
// update dynamic schema property with other properties should be idempotent.
resp, err = core.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Properties: []*commonpb.KeyValuePair{{Key: common.EnableDynamicSchemaKey, Value: "true"}},
})
require.NoError(t, merr.CheckRPCCall(resp, err))
assertDynamicSchema(t, ctx, core, dbName, collectionName, true)
assertSchemaVersion(t, ctx, core, dbName, collectionName, 1)
// disable dynamic schema property should return error.
resp, err = core.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Properties: []*commonpb.KeyValuePair{{Key: common.EnableDynamicSchemaKey, Value: "false"}},
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid)
}
func createCollectionForTest(t *testing.T, ctx context.Context, core *Core, dbName string, collectionName string) {
resp, err := core.CreateDatabase(ctx, &milvuspb.CreateDatabaseRequest{
DbName: dbName,
})
require.NoError(t, merr.CheckRPCCall(resp, err))
testSchema := &schemapb.CollectionSchema{
Name: collectionName,
Description: "description",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
Name: "field1",
DataType: schemapb.DataType_Int64,
},
},
}
schemaBytes, err := proto.Marshal(testSchema)
require.NoError(t, err)
resp, err = core.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Properties: []*commonpb.KeyValuePair{{Key: common.CollectionReplicaNumber, Value: "1"}},
Schema: schemaBytes,
ConsistencyLevel: commonpb.ConsistencyLevel_Bounded,
})
require.NoError(t, merr.CheckRPCCall(resp, err))
assertReplicaNumber(t, ctx, core, dbName, collectionName, 1)
assertConsistencyLevel(t, ctx, core, dbName, collectionName, commonpb.ConsistencyLevel_Bounded)
assertDescription(t, ctx, core, dbName, collectionName, "description")
assertSchemaVersion(t, ctx, core, dbName, collectionName, 0)
}
func createCollectionAndAliasForTest(t *testing.T, ctx context.Context, core *Core, dbName string, collectionName string) {
createCollectionForTest(t, ctx, core, dbName, collectionName)
// add an alias to the collection.
aliasName := collectionName + "_alias"
resp, err := core.CreateAlias(ctx, &milvuspb.CreateAliasRequest{
DbName: dbName,
CollectionName: collectionName,
Alias: aliasName,
})
require.NoError(t, merr.CheckRPCCall(resp, err))
assertReplicaNumber(t, ctx, core, dbName, aliasName, 1)
assertConsistencyLevel(t, ctx, core, dbName, aliasName, commonpb.ConsistencyLevel_Bounded)
assertDescription(t, ctx, core, dbName, aliasName, "description")
}
func assertReplicaNumber(t *testing.T, ctx context.Context, core *Core, dbName string, collectionName string, replicaNumber int64) {
coll, err := core.meta.GetCollectionByName(ctx, dbName, collectionName, typeutil.MaxTimestamp)
require.NoError(t, err)
replicaNum, err := common.CollectionLevelReplicaNumber(coll.Properties)
if replicaNumber == 0 {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, replicaNumber, replicaNum)
}
func assertResourceGroups(t *testing.T, ctx context.Context, core *Core, dbName string, collectionName string, resourceGroups []string) {
coll, err := core.meta.GetCollectionByName(ctx, dbName, collectionName, typeutil.MaxTimestamp)
require.NoError(t, err)
rgs, err := common.CollectionLevelResourceGroups(coll.Properties)
if len(resourceGroups) == 0 {
require.Error(t, err)
return
}
require.NoError(t, err)
require.ElementsMatch(t, resourceGroups, rgs)
}
func assertConsistencyLevel(t *testing.T, ctx context.Context, core *Core, dbName string, collectionName string, consistencyLevel commonpb.ConsistencyLevel) {
coll, err := core.meta.GetCollectionByName(ctx, dbName, collectionName, typeutil.MaxTimestamp)
require.NoError(t, err)
require.Equal(t, consistencyLevel, coll.ConsistencyLevel)
}
func assertDescription(t *testing.T, ctx context.Context, core *Core, dbName string, collectionName string, description string) {
coll, err := core.meta.GetCollectionByName(ctx, dbName, collectionName, typeutil.MaxTimestamp)
require.NoError(t, err)
require.Equal(t, description, coll.Description)
}
func assertSchemaVersion(t *testing.T, ctx context.Context, core *Core, dbName string, collectionName string, schemaVersion int32) {
coll, err := core.meta.GetCollectionByName(ctx, dbName, collectionName, typeutil.MaxTimestamp)
require.NoError(t, err)
require.Equal(t, schemaVersion, coll.SchemaVersion)
}
func assertDynamicSchema(t *testing.T, ctx context.Context, core *Core, dbName string, collectionName string, dynamicSchema bool) {
coll, err := core.meta.GetCollectionByName(ctx, dbName, collectionName, typeutil.MaxTimestamp)
require.NoError(t, err)
require.Equal(t, dynamicSchema, coll.EnableDynamicField)
if !dynamicSchema {
return
}
require.Len(t, coll.Fields, 4)
require.True(t, coll.Fields[len(coll.Fields)-1].IsDynamic)
require.Equal(t, coll.Fields[len(coll.Fields)-1].DataType, schemapb.DataType_JSON)
require.Equal(t, coll.Fields[len(coll.Fields)-1].FieldID, int64(101))
}

View File

@ -18,7 +18,6 @@ package rootcoord
import (
"context"
"strings"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
@ -28,7 +27,6 @@ import (
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/etcdpb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
@ -39,17 +37,19 @@ import (
)
func (c *Core) broadcastAlterDatabase(ctx context.Context, req *rootcoordpb.AlterDatabaseRequest) error {
req.DbName = strings.TrimSpace(req.DbName)
if req.GetDbName() == "" {
return merr.WrapErrParameterInvalidMsg("alter database failed, database name does not exists")
}
if req.GetProperties() == nil && req.GetDeleteKeys() == nil {
return errors.New("alter database with empty properties and delete keys, expected to set either properties or delete keys")
return merr.WrapErrParameterInvalidMsg("alter database with empty properties and delete keys, expected to set either properties or delete keys")
}
if len(req.GetProperties()) > 0 && len(req.GetDeleteKeys()) > 0 {
return errors.New("alter database cannot modify properties and delete keys at the same time")
return merr.WrapErrParameterInvalidMsg("alter database cannot modify properties and delete keys at the same time")
}
if hookutil.ContainsCipherProperties(req.GetProperties(), req.GetDeleteKeys()) {
return errors.New("can not alter cipher related properties")
return merr.WrapErrParameterInvalidMsg("can not alter cipher related properties")
}
broadcaster, err := startBroadcastWithDatabaseLock(ctx, req.GetDbName())
@ -71,12 +71,16 @@ func (c *Core) broadcastAlterDatabase(ctx context.Context, req *rootcoordpb.Alte
var newProperties []*commonpb.KeyValuePair
if (len(req.GetProperties())) > 0 {
if IsSubsetOfProperties(req.GetProperties(), oldDB.Properties) {
log.Info("skip to alter database due to no changes were detected in the properties")
return nil
// no changes were detected in the properties
return errIgnoredAlterDatabase
}
newProperties = MergeProperties(oldDB.Properties, req.GetProperties())
} else if (len(req.GetDeleteKeys())) > 0 {
newProperties = DeleteProperties(oldDB.Properties, req.GetDeleteKeys())
if len(newProperties) == len(oldDB.Properties) {
// no changes were detected in the properties
return errIgnoredAlterDatabase
}
}
msg := message.NewAlterDatabaseMessageBuilderV2().
@ -94,7 +98,7 @@ func (c *Core) broadcastAlterDatabase(ctx context.Context, req *rootcoordpb.Alte
return err
}
// getAlterLoadConfigOfAlterDatabase gets the put load config of put database.
// getAlterLoadConfigOfAlterDatabase gets the alter load config of alter database.
func (c *Core) getAlterLoadConfigOfAlterDatabase(ctx context.Context, dbName string, oldProps []*commonpb.KeyValuePair, newProps []*commonpb.KeyValuePair) (*message.AlterLoadConfigOfAlterDatabase, error) {
oldReplicaNumber, _ := common.DatabaseLevelReplicaNumber(oldProps)
oldResourceGroups, _ := common.DatabaseLevelResourceGroups(oldProps)
@ -129,24 +133,22 @@ func (c *DDLCallback) alterDatabaseV1AckCallback(ctx context.Context, result mes
if err := c.meta.AlterDatabase(ctx, db, result.GetControlChannelResult().TimeTick); err != nil {
return errors.Wrap(err, "failed to alter database")
}
if err := c.ExpireCaches(ctx, ce.NewBuilder().
WithLegacyProxyCollectionMetaCache(
ce.OptLPCMDBName(header.DbName),
ce.OptLPCMMsgType(commonpb.MsgType_AlterDatabase),
),
result.GetControlChannelResult().TimeTick); err != nil {
return errors.Wrap(err, "failed to expire caches")
}
if body.AlterLoadConfig != nil {
// TODO: should replaced with calling AlterLoadConfig message ack callback.
resp, err := c.mixCoord.UpdateLoadConfig(ctx, &querypb.UpdateLoadConfigRequest{
CollectionIDs: body.AlterLoadConfig.CollectionIds,
ReplicaNumber: body.AlterLoadConfig.ReplicaNumber,
ResourceGroups: body.AlterLoadConfig.ResourceGroups,
})
return merr.CheckRPCCall(resp, err)
if err := merr.CheckRPCCall(resp, err); err != nil {
return errors.Wrap(err, "failed to update load config")
}
}
return nil
return c.ExpireCaches(ctx, ce.NewBuilder().
WithLegacyProxyCollectionMetaCache(
ce.OptLPCMDBName(header.DbName),
ce.OptLPCMMsgType(commonpb.MsgType_AlterDatabase),
),
result.GetControlChannelResult().TimeTick)
}
func MergeProperties(oldProps, updatedProps []*commonpb.KeyValuePair) []*commonpb.KeyValuePair {

View File

@ -0,0 +1,140 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func TestDDLCallbacksAlterDatabase(t *testing.T) {
core := initStreamingSystemAndCore(t)
ctx := context.Background()
dbName := "testDB" + funcutil.RandomString(10)
// Cannot alter collection with empty properties and delete keys.
resp, err := core.AlterDatabase(ctx, &rootcoordpb.AlterDatabaseRequest{
DbName: dbName,
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid)
// Cannot alter collection properties with delete keys at same time.
resp, err = core.AlterDatabase(ctx, &rootcoordpb.AlterDatabaseRequest{
DbName: dbName,
Properties: []*commonpb.KeyValuePair{{Key: common.DatabaseReplicaNumber, Value: "1"}},
DeleteKeys: []string{common.DatabaseReplicaNumber},
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid)
// hook related properties are not allowed to be altered.
resp, err = core.AlterDatabase(ctx, &rootcoordpb.AlterDatabaseRequest{
DbName: dbName,
Properties: []*commonpb.KeyValuePair{{Key: hookutil.EncryptionEnabledKey, Value: "1"}},
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid)
// Alter a database that does not exist should return error.
resp, err = core.AlterDatabase(ctx, &rootcoordpb.AlterDatabaseRequest{
DbName: dbName,
DeleteKeys: []string{common.DatabaseReplicaNumber},
})
require.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrDatabaseNotFound)
resp, err = core.CreateDatabase(ctx, &milvuspb.CreateDatabaseRequest{
DbName: dbName,
Properties: []*commonpb.KeyValuePair{{Key: common.DatabaseReplicaNumber, Value: "1"}},
})
require.NoError(t, merr.CheckRPCCall(resp, err))
assertDatabaseReplicaNumber(t, ctx, core, dbName, 1)
// alter a property of a database.
resp, err = core.AlterDatabase(ctx, &rootcoordpb.AlterDatabaseRequest{
DbName: dbName,
Properties: []*commonpb.KeyValuePair{
{Key: common.DatabaseReplicaNumber, Value: "2"},
{Key: common.DatabaseResourceGroups, Value: "rg1,rg2"},
},
})
require.NoError(t, merr.CheckRPCCall(resp, err))
assertDatabaseReplicaNumber(t, ctx, core, dbName, 2)
assertDatabaseResourceGroups(t, ctx, core, dbName, []string{"rg1", "rg2"})
// alter a property of a database should be idempotent.
resp, err = core.AlterDatabase(ctx, &rootcoordpb.AlterDatabaseRequest{
DbName: dbName,
Properties: []*commonpb.KeyValuePair{
{Key: common.DatabaseReplicaNumber, Value: "2"},
{Key: common.DatabaseResourceGroups, Value: "rg1,rg2"},
},
})
require.NoError(t, merr.CheckRPCCall(resp, err))
assertDatabaseReplicaNumber(t, ctx, core, dbName, 2)
assertDatabaseResourceGroups(t, ctx, core, dbName, []string{"rg1", "rg2"})
// delete a property of a database.
resp, err = core.AlterDatabase(ctx, &rootcoordpb.AlterDatabaseRequest{
DbName: dbName,
DeleteKeys: []string{common.DatabaseReplicaNumber, common.DatabaseResourceGroups},
})
require.NoError(t, merr.CheckRPCCall(resp, err))
assertDatabaseReplicaNumber(t, ctx, core, dbName, 0)
assertDatabaseResourceGroups(t, ctx, core, dbName, []string{})
// delete a property of a collection should be idempotent.
resp, err = core.AlterDatabase(ctx, &rootcoordpb.AlterDatabaseRequest{
DbName: dbName,
DeleteKeys: []string{common.DatabaseReplicaNumber, common.DatabaseResourceGroups},
})
require.NoError(t, merr.CheckRPCCall(resp, err))
assertDatabaseReplicaNumber(t, ctx, core, dbName, 0)
assertDatabaseResourceGroups(t, ctx, core, dbName, []string{})
}
func assertDatabaseReplicaNumber(t *testing.T, ctx context.Context, core *Core, dbName string, replicaNumber int64) {
db, err := core.meta.GetDatabaseByName(ctx, dbName, typeutil.MaxTimestamp)
require.NoError(t, err)
replicaNum, err := common.DatabaseLevelReplicaNumber(db.Properties)
if replicaNumber == 0 {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, replicaNumber, replicaNum)
}
func assertDatabaseResourceGroups(t *testing.T, ctx context.Context, core *Core, dbName string, resourceGroups []string) {
db, err := core.meta.GetDatabaseByName(ctx, dbName, typeutil.MaxTimestamp)
require.NoError(t, err)
rgs, err := common.DatabaseLevelResourceGroups(db.Properties)
if len(resourceGroups) == 0 {
require.Error(t, err)
return
}
require.NoError(t, err)
require.ElementsMatch(t, resourceGroups, rgs)
}

View File

@ -52,10 +52,14 @@ import (
var (
errIgnoredAlterAlias = errors.New("ignored alter alias") // alias already created on current collection, so it can be ignored.
errIgnoredAlterCollection = errors.New("ignored alter collection") // collection already created, so it can be ignored.
errIgnoredAlterDatabase = errors.New("ignored alter database") // database already created, so it can be ignored.
errIgnoredCreateCollection = errors.New("ignored create collection") // create collection with same schema, so it can be ignored.
errIgnoerdCreatePartition = errors.New("ignored create partition") // partition is already exist, so it can be ignored.
errIgnoredDropCollection = errors.New("ignored drop collection") // drop collection or database not found, so it can be ignored.
errIgnoredDropPartition = errors.New("ignored drop partition") // drop partition not found, so it can be ignored.
errAlterCollectionNotFound = errors.New("alter collection not found") // alter collection not found, so it can be ignored.
)
type MetaTableChecker interface {
@ -108,8 +112,8 @@ type IMetaTable interface {
DescribeAlias(ctx context.Context, dbName string, alias string, ts Timestamp) (string, error)
ListAliases(ctx context.Context, dbName string, collectionName string, ts Timestamp) ([]string, error)
AlterCollection(ctx context.Context, oldColl *model.Collection, newColl *model.Collection, ts Timestamp, fieldModify bool) error
RenameCollection(ctx context.Context, dbName string, oldName string, newDBName string, newName string, ts Timestamp) error
AlterCollection(ctx context.Context, result message.BroadcastResultAlterCollectionMessageV2) error
CheckIfCollectionRenamable(ctx context.Context, dbName string, oldName string, newDBName string, newName string) error
GetGeneralCount(ctx context.Context) int
// TODO: it'll be a big cost if we handle the time travel logic, since we should always list all aliases in catalog.
@ -349,6 +353,7 @@ func (mt *MetaTable) CheckIfDatabaseCreatable(ctx context.Context, req *milvuspb
defer mt.ddLock.RUnlock()
if _, ok := mt.dbName2Meta[dbName]; ok || mt.aliases.exist(dbName) || mt.names.exist(dbName) {
// TODO: idempotency check here.
return fmt.Errorf("database already exist: %s", dbName)
}
@ -551,6 +556,8 @@ func (mt *MetaTable) DropCollection(ctx context.Context, collectionID UniqueID,
clone := coll.Clone()
clone.State = pb.CollectionState_CollectionDropping
clone.UpdateTimestamp = ts
ctx1 := contextutil.WithTenantID(ctx, Params.CommonCfg.ClusterName.GetValue())
if err := mt.catalog.AlterCollection(ctx1, coll, clone, metastore.MODIFY, ts, false); err != nil {
return err
@ -923,22 +930,56 @@ func (mt *MetaTable) ListCollectionPhysicalChannels(ctx context.Context) map[typ
return chanMap
}
func (mt *MetaTable) AlterCollection(ctx context.Context, oldColl *model.Collection, newColl *model.Collection, ts Timestamp, fieldModify bool) error {
// AlterCollection is used to alter a collection in the meta table.
func (mt *MetaTable) AlterCollection(ctx context.Context, result message.BroadcastResultAlterCollectionMessageV2) error {
header := result.Message.Header()
body := result.Message.MustBody()
mt.ddLock.Lock()
defer mt.ddLock.Unlock()
ctx1 := contextutil.WithTenantID(ctx, Params.CommonCfg.ClusterName.GetValue())
if err := mt.catalog.AlterCollection(ctx1, oldColl, newColl, metastore.MODIFY, ts, fieldModify); err != nil {
return err
coll, ok := mt.collID2Meta[header.CollectionId]
if !ok {
// collection not exists, return directly.
return errAlterCollectionNotFound
}
mt.collID2Meta[oldColl.CollectionID] = newColl
log.Ctx(ctx).Info("alter collection finished", zap.Int64("collectionID", oldColl.CollectionID), zap.Uint64("ts", ts))
oldColl := coll.Clone()
newColl := coll.Clone()
newColl.ApplyUpdates(header, body)
fieldModify := false
dbChanged := false
for _, path := range header.UpdateMask.GetPaths() {
switch path {
case message.FieldMaskCollectionSchema:
fieldModify = true
case message.FieldMaskDB:
dbChanged = true
}
}
newColl.UpdateTimestamp = result.GetControlChannelResult().TimeTick
ctx1 := contextutil.WithTenantID(ctx, Params.CommonCfg.ClusterName.GetValue())
if !dbChanged {
if err := mt.catalog.AlterCollection(ctx1, oldColl, newColl, metastore.MODIFY, newColl.UpdateTimestamp, fieldModify); err != nil {
return err
}
} else {
if err := mt.catalog.AlterCollectionDB(ctx1, oldColl, newColl, newColl.UpdateTimestamp); err != nil {
return err
}
}
mt.names.remove(oldColl.DBName, oldColl.Name)
mt.names.insert(newColl.DBName, newColl.Name, newColl.CollectionID)
mt.collID2Meta[header.CollectionId] = newColl
log.Ctx(ctx).Info("alter collection finished", zap.Bool("dbChanged", dbChanged), zap.Int64("collectionID", oldColl.CollectionID), zap.Uint64("ts", newColl.UpdateTimestamp))
return nil
}
func (mt *MetaTable) RenameCollection(ctx context.Context, dbName string, oldName string, newDBName string, newName string, ts Timestamp) error {
mt.ddLock.Lock()
defer mt.ddLock.Unlock()
func (mt *MetaTable) CheckIfCollectionRenamable(ctx context.Context, dbName string, oldName string, newDBName string, newName string) error {
mt.ddLock.RLock()
defer mt.ddLock.RUnlock()
ctx = contextutil.WithTenantID(ctx, Params.CommonCfg.ClusterName.GetValue())
log := log.Ctx(ctx).With(
@ -955,8 +996,21 @@ func (mt *MetaTable) RenameCollection(ctx context.Context, dbName string, oldNam
return fmt.Errorf("target database:%s not found", newDBName)
}
// check if new name already belongs to another existing collection
coll, err := mt.getCollectionByNameInternal(ctx, newDBName, newName, ts)
// old collection should not be an alias
_, ok = mt.aliases.get(dbName, oldName)
if ok {
log.Warn("unsupported use a alias to rename collection")
return fmt.Errorf("unsupported use an alias to rename collection, alias:%s", oldName)
}
_, ok = mt.aliases.get(newDBName, newName)
if ok {
log.Warn("cannot rename collection to an existing alias")
return fmt.Errorf("cannot rename collection to an existing alias: %s", newName)
}
// check new collection already exists
coll, err := mt.getCollectionByNameInternal(ctx, newDBName, newName, typeutil.MaxTimestamp)
if coll != nil {
log.Warn("duplicated new collection name, already taken by another collection or alias.")
return fmt.Errorf("duplicated new collection name %s:%s with other collection name or alias", newDBName, newName)
@ -967,34 +1021,17 @@ func (mt *MetaTable) RenameCollection(ctx context.Context, dbName string, oldNam
}
// get old collection meta
oldColl, err := mt.getCollectionByNameInternal(ctx, dbName, oldName, ts)
oldColl, err := mt.getCollectionByNameInternal(ctx, dbName, oldName, typeutil.MaxTimestamp)
if err != nil {
log.Warn("fail to find collection with old name", zap.Error(err))
return err
}
newColl := oldColl.Clone()
newColl.Name = newName
newColl.DBName = dbName
newColl.DBID = targetDB.ID
if oldColl.DBID == newColl.DBID {
if err := mt.catalog.AlterCollection(ctx, oldColl, newColl, metastore.MODIFY, ts, false); err != nil {
log.Warn("alter collection by catalog failed", zap.Error(err))
return err
}
} else {
if err := mt.catalog.AlterCollectionDB(ctx, oldColl, newColl, ts); err != nil {
log.Warn("alter collectionDB by catalog failed", zap.Error(err))
return err
}
// unsupported rename collection while the collection has aliases
aliases := mt.listAliasesByID(oldColl.CollectionID)
if len(aliases) > 0 && oldColl.DBID != targetDB.ID {
return errors.New("fail to rename db name, must drop all aliases of this collection before rename")
}
mt.names.insert(newDBName, newName, oldColl.CollectionID)
mt.names.remove(dbName, oldName)
mt.collID2Meta[oldColl.CollectionID] = newColl
log.Info("rename collection finished")
return nil
}

View File

@ -786,6 +786,7 @@ func TestMetaTable_GetCollectionByName(t *testing.T) {
})
}
/*
func TestMetaTable_AlterCollection(t *testing.T) {
t.Run("alter metastore fail", func(t *testing.T) {
catalog := mocks.NewRootCoordCatalog(t)
@ -852,6 +853,7 @@ func TestMetaTable_AlterCollection(t *testing.T) {
assert.Equal(t, meta.collID2Meta[1], newColl)
})
}
*/
func TestMetaTable_DescribeAlias(t *testing.T) {
t.Run("metatable describe alias ok", func(t *testing.T) {
@ -1513,6 +1515,7 @@ func TestMetaTable_AddPartition(t *testing.T) {
})
}
/*
func TestMetaTable_RenameCollection(t *testing.T) {
t.Run("unsupported use a alias to rename collection", func(t *testing.T) {
meta := &MetaTable{
@ -1795,6 +1798,7 @@ func TestMetaTable_RenameCollection(t *testing.T) {
assert.Equal(t, "new", coll.Name)
})
}
*/
func TestMetaTable_CreateDatabase(t *testing.T) {
db := model.NewDatabase(1, "exist", pb.DatabaseState_DatabaseCreated, nil)

View File

@ -50,7 +50,6 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/retry"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -83,7 +82,7 @@ type mockMetaTable struct {
GetCollectionIDByNameFunc func(name string) (UniqueID, error)
GetPartitionByNameFunc func(collID UniqueID, partitionName string, ts Timestamp) (UniqueID, error)
GetCollectionVirtualChannelsFunc func(ctx context.Context, colID int64) []string
AlterCollectionFunc func(ctx context.Context, oldColl *model.Collection, newColl *model.Collection, ts Timestamp, fieldModify bool) error
AlterCollectionFunc func(ctx context.Context, result message.BroadcastResultAlterCollectionMessageV2) error
RenameCollectionFunc func(ctx context.Context, oldName string, newName string, ts Timestamp) error
AddCredentialFunc func(ctx context.Context, credInfo *internalpb.CredentialInfo) error
GetCredentialFunc func(ctx context.Context, username string) (*internalpb.CredentialInfo, error)
@ -177,8 +176,8 @@ func (m mockMetaTable) ListAliasesByID(ctx context.Context, collID UniqueID) []s
return m.ListAliasesByIDFunc(ctx, collID)
}
func (m mockMetaTable) AlterCollection(ctx context.Context, oldColl *model.Collection, newColl *model.Collection, ts Timestamp, fieldModify bool) error {
return m.AlterCollectionFunc(ctx, oldColl, newColl, ts, fieldModify)
func (m mockMetaTable) AlterCollection(ctx context.Context, result message.BroadcastResultAlterCollectionMessageV2) error {
return m.AlterCollectionFunc(ctx, result)
}
func (m *mockMetaTable) RenameCollection(ctx context.Context, dbName string, oldName string, newDBName string, newName string, ts Timestamp) error {
@ -418,13 +417,6 @@ func newTestCore(opts ...Opt) *Core {
session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: TestRootCoordID}},
tombstoneSweeper: tombstoneSweeper,
}
executor := newMockStepExecutor()
executor.AddStepsFunc = func(s *stepStack) {
// no schedule, execute directly.
s.Execute(context.Background())
}
executor.StopFunc = func() {}
c.stepExecutor = executor
for _, opt := range opts {
opt(c)
}
@ -625,19 +617,6 @@ func withMixCoord(mixc types.MixCoord) Opt {
}
}
func withUnhealthyMixCoord() Opt {
mixc := &mocks.MixCoord{}
err := errors.New("mock error")
mixc.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(
&milvuspb.ComponentStates{
State: &milvuspb.ComponentInfo{StateCode: commonpb.StateCode_Abnormal},
Status: merr.Status(err),
}, retry.Unrecoverable(errors.New("error mock GetComponentStates")),
)
return withMixCoord(mixc)
}
func withInvalidMixCoord() Opt {
mixc := &mocks.MixCoord{}
mixc.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(
@ -763,6 +742,7 @@ func withValidMixCoord() Opt {
merr.Success(), nil,
)
mixc.EXPECT().NotifyDropPartition(mock.Anything, mock.Anything, mock.Anything).Return(nil)
mixc.EXPECT().UpdateLoadConfig(mock.Anything, mock.Anything).Return(merr.Success(), nil)
return withMixCoord(mixc)
}
@ -915,7 +895,7 @@ type mockBroker struct {
DropCollectionIndexFunc func(ctx context.Context, collID UniqueID, partIDs []UniqueID) error
GetSegmentIndexStateFunc func(ctx context.Context, collID UniqueID, indexName string, segIDs []UniqueID) ([]*indexpb.SegmentIndexState, error)
BroadcastAlteredCollectionFunc func(ctx context.Context, req *milvuspb.AlterCollectionRequest) error
BroadcastAlteredCollectionFunc func(ctx context.Context, collectionID UniqueID) error
GCConfirmFunc func(ctx context.Context, collectionID, partitionID UniqueID) bool
}
@ -934,6 +914,9 @@ func newValidMockBroker() *mockBroker {
broker.DropCollectionIndexFunc = func(ctx context.Context, collID UniqueID, partIDs []UniqueID) error {
return nil
}
broker.BroadcastAlteredCollectionFunc = func(ctx context.Context, collectionID UniqueID) error {
return nil
}
return broker
}
@ -969,8 +952,8 @@ func (b mockBroker) GetSegmentIndexState(ctx context.Context, collID UniqueID, i
return b.GetSegmentIndexStateFunc(ctx, collID, indexName, segIDs)
}
func (b mockBroker) BroadcastAlteredCollection(ctx context.Context, req *milvuspb.AlterCollectionRequest) error {
return b.BroadcastAlteredCollectionFunc(ctx, req)
func (b mockBroker) BroadcastAlteredCollection(ctx context.Context, collectionID UniqueID) error {
return b.BroadcastAlteredCollectionFunc(ctx, collectionID)
}
func (b mockBroker) GcConfirm(ctx context.Context, collectionID, partitionID UniqueID) bool {
@ -1090,38 +1073,3 @@ func withDdlTsLockManager(m DdlTsLockManager) Opt {
c.ddlTsLockManager = m
}
}
type mockStepExecutor struct {
StepExecutor
StartFunc func()
StopFunc func()
AddStepsFunc func(s *stepStack)
}
func newMockStepExecutor() *mockStepExecutor {
return &mockStepExecutor{}
}
func (m mockStepExecutor) Start() {
if m.StartFunc != nil {
m.StartFunc()
}
}
func (m mockStepExecutor) Stop() {
if m.StopFunc != nil {
m.StopFunc()
}
}
func (m mockStepExecutor) AddSteps(s *stepStack) {
if m.AddStepsFunc != nil {
m.AddStepsFunc(s)
}
}
func withStepExecutor(executor StepExecutor) Opt {
return func(c *Core) {
c.stepExecutor = executor
}
}

View File

@ -173,17 +173,17 @@ func (_c *IMetaTable_AlterAlias_Call) RunAndReturn(run func(context.Context, mes
return _c
}
// AlterCollection provides a mock function with given fields: ctx, oldColl, newColl, ts, fieldModify
func (_m *IMetaTable) AlterCollection(ctx context.Context, oldColl *model.Collection, newColl *model.Collection, ts uint64, fieldModify bool) error {
ret := _m.Called(ctx, oldColl, newColl, ts, fieldModify)
// AlterCollection provides a mock function with given fields: ctx, result
func (_m *IMetaTable) AlterCollection(ctx context.Context, result message.BroadcastResult[*messagespb.AlterCollectionMessageHeader, *messagespb.AlterCollectionMessageBody]) error {
ret := _m.Called(ctx, result)
if len(ret) == 0 {
panic("no return value specified for AlterCollection")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *model.Collection, *model.Collection, uint64, bool) error); ok {
r0 = rf(ctx, oldColl, newColl, ts, fieldModify)
if rf, ok := ret.Get(0).(func(context.Context, message.BroadcastResult[*messagespb.AlterCollectionMessageHeader, *messagespb.AlterCollectionMessageBody]) error); ok {
r0 = rf(ctx, result)
} else {
r0 = ret.Error(0)
}
@ -198,17 +198,14 @@ type IMetaTable_AlterCollection_Call struct {
// AlterCollection is a helper method to define mock.On call
// - ctx context.Context
// - oldColl *model.Collection
// - newColl *model.Collection
// - ts uint64
// - fieldModify bool
func (_e *IMetaTable_Expecter) AlterCollection(ctx interface{}, oldColl interface{}, newColl interface{}, ts interface{}, fieldModify interface{}) *IMetaTable_AlterCollection_Call {
return &IMetaTable_AlterCollection_Call{Call: _e.mock.On("AlterCollection", ctx, oldColl, newColl, ts, fieldModify)}
// - result message.BroadcastResult[*messagespb.AlterCollectionMessageHeader,*messagespb.AlterCollectionMessageBody]
func (_e *IMetaTable_Expecter) AlterCollection(ctx interface{}, result interface{}) *IMetaTable_AlterCollection_Call {
return &IMetaTable_AlterCollection_Call{Call: _e.mock.On("AlterCollection", ctx, result)}
}
func (_c *IMetaTable_AlterCollection_Call) Run(run func(ctx context.Context, oldColl *model.Collection, newColl *model.Collection, ts uint64, fieldModify bool)) *IMetaTable_AlterCollection_Call {
func (_c *IMetaTable_AlterCollection_Call) Run(run func(ctx context.Context, result message.BroadcastResult[*messagespb.AlterCollectionMessageHeader, *messagespb.AlterCollectionMessageBody])) *IMetaTable_AlterCollection_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*model.Collection), args[2].(*model.Collection), args[3].(uint64), args[4].(bool))
run(args[0].(context.Context), args[1].(message.BroadcastResult[*messagespb.AlterCollectionMessageHeader, *messagespb.AlterCollectionMessageBody]))
})
return _c
}
@ -218,7 +215,7 @@ func (_c *IMetaTable_AlterCollection_Call) Return(_a0 error) *IMetaTable_AlterCo
return _c
}
func (_c *IMetaTable_AlterCollection_Call) RunAndReturn(run func(context.Context, *model.Collection, *model.Collection, uint64, bool) error) *IMetaTable_AlterCollection_Call {
func (_c *IMetaTable_AlterCollection_Call) RunAndReturn(run func(context.Context, message.BroadcastResult[*messagespb.AlterCollectionMessageHeader, *messagespb.AlterCollectionMessageBody]) error) *IMetaTable_AlterCollection_Call {
_c.Call.Return(run)
return _c
}
@ -570,6 +567,56 @@ func (_c *IMetaTable_CheckIfAliasDroppable_Call) RunAndReturn(run func(context.C
return _c
}
// CheckIfCollectionRenamable provides a mock function with given fields: ctx, dbName, oldName, newDBName, newName
func (_m *IMetaTable) CheckIfCollectionRenamable(ctx context.Context, dbName string, oldName string, newDBName string, newName string) error {
ret := _m.Called(ctx, dbName, oldName, newDBName, newName)
if len(ret) == 0 {
panic("no return value specified for CheckIfCollectionRenamable")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok {
r0 = rf(ctx, dbName, oldName, newDBName, newName)
} else {
r0 = ret.Error(0)
}
return r0
}
// IMetaTable_CheckIfCollectionRenamable_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckIfCollectionRenamable'
type IMetaTable_CheckIfCollectionRenamable_Call struct {
*mock.Call
}
// CheckIfCollectionRenamable is a helper method to define mock.On call
// - ctx context.Context
// - dbName string
// - oldName string
// - newDBName string
// - newName string
func (_e *IMetaTable_Expecter) CheckIfCollectionRenamable(ctx interface{}, dbName interface{}, oldName interface{}, newDBName interface{}, newName interface{}) *IMetaTable_CheckIfCollectionRenamable_Call {
return &IMetaTable_CheckIfCollectionRenamable_Call{Call: _e.mock.On("CheckIfCollectionRenamable", ctx, dbName, oldName, newDBName, newName)}
}
func (_c *IMetaTable_CheckIfCollectionRenamable_Call) Run(run func(ctx context.Context, dbName string, oldName string, newDBName string, newName string)) *IMetaTable_CheckIfCollectionRenamable_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string))
})
return _c
}
func (_c *IMetaTable_CheckIfCollectionRenamable_Call) Return(_a0 error) *IMetaTable_CheckIfCollectionRenamable_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *IMetaTable_CheckIfCollectionRenamable_Call) RunAndReturn(run func(context.Context, string, string, string, string) error) *IMetaTable_CheckIfCollectionRenamable_Call {
_c.Call.Return(run)
return _c
}
// CheckIfCreateRole provides a mock function with given fields: ctx, req
func (_m *IMetaTable) CheckIfCreateRole(ctx context.Context, req *milvuspb.CreateRoleRequest) error {
ret := _m.Called(ctx, req)
@ -3288,57 +3335,6 @@ func (_c *IMetaTable_RemovePartition_Call) RunAndReturn(run func(context.Context
return _c
}
// RenameCollection provides a mock function with given fields: ctx, dbName, oldName, newDBName, newName, ts
func (_m *IMetaTable) RenameCollection(ctx context.Context, dbName string, oldName string, newDBName string, newName string, ts uint64) error {
ret := _m.Called(ctx, dbName, oldName, newDBName, newName, ts)
if len(ret) == 0 {
panic("no return value specified for RenameCollection")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string, uint64) error); ok {
r0 = rf(ctx, dbName, oldName, newDBName, newName, ts)
} else {
r0 = ret.Error(0)
}
return r0
}
// IMetaTable_RenameCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RenameCollection'
type IMetaTable_RenameCollection_Call struct {
*mock.Call
}
// RenameCollection is a helper method to define mock.On call
// - ctx context.Context
// - dbName string
// - oldName string
// - newDBName string
// - newName string
// - ts uint64
func (_e *IMetaTable_Expecter) RenameCollection(ctx interface{}, dbName interface{}, oldName interface{}, newDBName interface{}, newName interface{}, ts interface{}) *IMetaTable_RenameCollection_Call {
return &IMetaTable_RenameCollection_Call{Call: _e.mock.On("RenameCollection", ctx, dbName, oldName, newDBName, newName, ts)}
}
func (_c *IMetaTable_RenameCollection_Call) Run(run func(ctx context.Context, dbName string, oldName string, newDBName string, newName string, ts uint64)) *IMetaTable_RenameCollection_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string), args[5].(uint64))
})
return _c
}
func (_c *IMetaTable_RenameCollection_Call) Return(_a0 error) *IMetaTable_RenameCollection_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *IMetaTable_RenameCollection_Call) RunAndReturn(run func(context.Context, string, string, string, string, uint64) error) *IMetaTable_RenameCollection_Call {
_c.Call.Return(run)
return _c
}
// RestoreRBAC provides a mock function with given fields: ctx, tenant, meta
func (_m *IMetaTable) RestoreRBAC(ctx context.Context, tenant string, meta *milvuspb.RBACMeta) error {
ret := _m.Called(ctx, tenant, meta)

View File

@ -1,70 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"go.uber.org/zap"
"github.com/milvus-io/milvus/pkg/v2/log"
)
type baseRedoTask struct {
syncTodoStep []nestedStep // steps to execute synchronously
asyncTodoStep []nestedStep // steps to execute asynchronously
stepExecutor StepExecutor
}
func newBaseRedoTask(stepExecutor StepExecutor) *baseRedoTask {
return &baseRedoTask{
syncTodoStep: make([]nestedStep, 0),
asyncTodoStep: make([]nestedStep, 0),
stepExecutor: stepExecutor,
}
}
func (b *baseRedoTask) AddSyncStep(step nestedStep) {
b.syncTodoStep = append(b.syncTodoStep, step)
}
func (b *baseRedoTask) AddAsyncStep(step nestedStep) {
b.asyncTodoStep = append(b.asyncTodoStep, step)
}
func (b *baseRedoTask) redoAsyncSteps() {
l := len(b.asyncTodoStep)
steps := make([]nestedStep, 0, l)
for i := l - 1; i >= 0; i-- {
steps = append(steps, b.asyncTodoStep[i])
}
b.asyncTodoStep = nil // make baseRedoTask can be collected.
b.stepExecutor.AddSteps(&stepStack{steps: steps})
}
func (b *baseRedoTask) Execute(ctx context.Context) error {
for i := 0; i < len(b.syncTodoStep); i++ {
todo := b.syncTodoStep[i]
// no children step in sync steps.
if _, err := todo.Execute(ctx); err != nil {
log.Error("failed to execute step", zap.Error(err), zap.String("desc", todo.Desc()))
return err
}
}
go b.redoAsyncSteps()
return nil
}

View File

@ -1,151 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
)
type mockFailStep struct {
baseStep
calledChan chan struct{}
called bool
err error
}
func newMockFailStep() *mockFailStep {
return &mockFailStep{calledChan: make(chan struct{}, 1), called: false}
}
func (m *mockFailStep) Execute(ctx context.Context) ([]nestedStep, error) {
m.called = true
m.calledChan <- struct{}{}
if m.err != nil {
return nil, m.err
}
return nil, errors.New("error mock Execute")
}
type mockNormalStep struct {
nestedStep
calledChan chan struct{}
called bool
}
func newMockNormalStep() *mockNormalStep {
return &mockNormalStep{calledChan: make(chan struct{}, 1), called: false}
}
func (m *mockNormalStep) Execute(ctx context.Context) ([]nestedStep, error) {
m.called = true
m.calledChan <- struct{}{}
return nil, nil
}
func newTestRedoTask() *baseRedoTask {
stepExecutor := newMockStepExecutor()
stepExecutor.AddStepsFunc = func(s *stepStack) {
// no schedule, execute directly.
s.Execute(context.Background())
}
redo := newBaseRedoTask(stepExecutor)
return redo
}
func Test_baseRedoTask_redoAsyncSteps(t *testing.T) {
t.Run("partial error", func(t *testing.T) {
redo := newTestRedoTask()
steps := []nestedStep{newMockNormalStep(), newMockFailStep(), newMockNormalStep()}
for _, step := range steps {
redo.AddAsyncStep(step)
}
redo.redoAsyncSteps()
assert.True(t, steps[0].(*mockNormalStep).called)
assert.False(t, steps[2].(*mockNormalStep).called)
})
t.Run("normal case", func(t *testing.T) {
redo := newTestRedoTask()
n := 10
steps := make([]nestedStep, 0, n)
for i := 0; i < n; i++ {
steps = append(steps, newMockNormalStep())
}
for _, step := range steps {
redo.AddAsyncStep(step)
}
redo.redoAsyncSteps()
for _, step := range steps {
assert.True(t, step.(*mockNormalStep).called)
}
})
}
func Test_baseRedoTask_Execute(t *testing.T) {
t.Run("sync not finished, no async task", func(t *testing.T) {
redo := newTestRedoTask()
syncSteps := []nestedStep{newMockFailStep()}
asyncNum := 10
asyncSteps := make([]nestedStep, 0, asyncNum)
for i := 0; i < asyncNum; i++ {
asyncSteps = append(asyncSteps, newMockNormalStep())
}
for _, step := range asyncSteps {
redo.AddAsyncStep(step)
}
for _, step := range syncSteps {
redo.AddSyncStep(step)
}
err := redo.Execute(context.Background())
assert.Error(t, err)
for _, step := range asyncSteps {
assert.False(t, step.(*mockNormalStep).called)
}
})
// TODO: sync finished, but some async fail.
t.Run("normal case", func(t *testing.T) {
redo := newTestRedoTask()
syncNum := 10
syncSteps := make([]nestedStep, 0, syncNum)
asyncNum := 10
asyncSteps := make([]nestedStep, 0, asyncNum)
for i := 0; i < syncNum; i++ {
syncSteps = append(syncSteps, newMockNormalStep())
}
for i := 0; i < asyncNum; i++ {
asyncSteps = append(asyncSteps, newMockNormalStep())
}
for _, step := range asyncSteps {
redo.AddAsyncStep(step)
}
for _, step := range syncSteps {
redo.AddSyncStep(step)
}
err := redo.Execute(context.Background())
assert.NoError(t, err)
for _, step := range asyncSteps {
<-step.(*mockNormalStep).calledChan
assert.True(t, step.(*mockNormalStep).called)
}
})
}

View File

@ -1,141 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"fmt"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/internal/util/proxyutil"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
type renameCollectionTask struct {
baseTask
Req *milvuspb.RenameCollectionRequest
}
func (t *renameCollectionTask) Prepare(ctx context.Context) error {
if err := CheckMsgType(t.Req.GetBase().GetMsgType(), commonpb.MsgType_RenameCollection); err != nil {
return err
}
if t.Req.GetDbName() == "" {
t.Req.DbName = util.DefaultDBName
}
if t.Req.GetNewDBName() == "" {
t.Req.NewDBName = t.Req.GetDbName()
}
return nil
}
func (t *renameCollectionTask) Execute(ctx context.Context) error {
// Check if renaming across databases with encryption enabled
if t.Req.GetNewDBName() != t.Req.GetDbName() {
if err := t.validateEncryption(ctx); err != nil {
return err
}
}
targetDB := t.Req.GetNewDBName()
// check old collection isn't alias and exists in old db
if t.core.meta.IsAlias(ctx, t.Req.GetDbName(), t.Req.GetOldName()) {
return fmt.Errorf("unsupported use an alias to rename collection, alias:%s", t.Req.GetOldName())
}
collInfo, err := t.core.meta.GetCollectionByName(ctx, t.Req.GetDbName(), t.Req.GetOldName(), typeutil.MaxTimestamp)
if err != nil {
return fmt.Errorf("collection not found in database, collection: %s, database: %s", t.Req.GetOldName(), t.Req.GetDbName())
}
// check old collection doesn't have aliases if renaming databases
aliases := t.core.meta.ListAliasesByID(ctx, collInfo.CollectionID)
if len(aliases) > 0 && targetDB != t.Req.GetDbName() {
return fmt.Errorf("fail to rename collection to different database, must drop all aliases of collection %s before rename", t.Req.GetOldName())
}
// check new collection isn't alias and not exists in new db
if t.core.meta.IsAlias(ctx, targetDB, t.Req.GetNewName()) {
return fmt.Errorf("cannot rename collection to an existing alias: %s", t.Req.GetNewName())
}
_, err = t.core.meta.GetCollectionByName(ctx, targetDB, t.Req.GetNewName(), typeutil.MaxTimestamp)
if err == nil {
return fmt.Errorf("duplicated new collection name %s in database %s with other collection name or alias", t.Req.GetNewName(), targetDB)
}
ts := t.GetTs()
redoTask := newBaseRedoTask(t.core.stepExecutor)
// Step 1: Rename collection in metadata catalog
redoTask.AddSyncStep(&renameCollectionStep{
baseStep: baseStep{core: t.core},
dbName: t.Req.GetDbName(),
oldName: t.Req.GetOldName(),
newDBName: t.Req.GetNewDBName(),
newName: t.Req.GetNewName(),
ts: ts,
})
// Step 2: Expire cache for old collection name
redoTask.AddSyncStep(&expireCacheStep{
baseStep: baseStep{core: t.core},
dbName: t.Req.GetDbName(),
collectionNames: append(aliases, t.Req.GetOldName()),
collectionID: collInfo.CollectionID,
ts: ts,
opts: []proxyutil.ExpireCacheOpt{proxyutil.SetMsgType(commonpb.MsgType_RenameCollection)},
})
return redoTask.Execute(ctx)
}
func (t *renameCollectionTask) validateEncryption(ctx context.Context) error {
// old and new DB names are filled in Prepare, shouldn't be empty here
oldDBName := t.Req.GetDbName()
newDBName := t.Req.GetNewDBName()
originalDB, err := t.core.meta.GetDatabaseByName(ctx, oldDBName, 0)
if err != nil {
return fmt.Errorf("failed to get original database: %w", err)
}
targetDB, err := t.core.meta.GetDatabaseByName(ctx, newDBName, 0)
if err != nil {
return fmt.Errorf("target database %s not found: %w", newDBName, err)
}
// Check if either database has encryption enabled
if hookutil.IsDBEncryptionEnabled(originalDB.Properties) || hookutil.IsDBEncryptionEnabled(targetDB.Properties) {
return fmt.Errorf("deny to change collection databases due to at least one database enabled encryption, original DB: %s, target DB: %s", oldDBName, newDBName)
}
return nil
}
func (t *renameCollectionTask) GetLockerKey() LockerKey {
return NewLockerKeyChain(
NewClusterLockerKey(true),
)
}

View File

@ -1,783 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/metastore/model"
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
)
func Test_renameCollectionTask_Prepare(t *testing.T) {
t.Run("invalid msg type", func(t *testing.T) {
task := &renameCollectionTask{
Req: &milvuspb.RenameCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Undefined,
},
},
}
err := task.Prepare(context.Background())
assert.Error(t, err)
})
t.Run("normal case same database", func(t *testing.T) {
task := &renameCollectionTask{
Req: &milvuspb.RenameCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_RenameCollection,
},
DbName: "db1",
OldName: "old_collection",
NewName: "new_collection",
},
}
err := task.Prepare(context.Background())
assert.NoError(t, err)
})
t.Run("cross database with encryption enabled", func(t *testing.T) {
oldDB := &model.Database{
Name: "db1",
ID: 1,
Properties: []*commonpb.KeyValuePair{
{Key: "cipher.enabled", Value: "true"},
},
}
newDB := &model.Database{
Name: "db2",
ID: 2,
}
meta := mockrootcoord.NewIMetaTable(t)
meta.On("GetDatabaseByName",
mock.Anything,
"db1",
mock.AnythingOfType("uint64"),
).Return(oldDB, nil)
meta.On("GetDatabaseByName",
mock.Anything,
"db2",
mock.AnythingOfType("uint64"),
).Return(newDB, nil)
core := newTestCore(withMeta(meta))
task := &renameCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.RenameCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_RenameCollection,
},
DbName: "db1",
OldName: "old_collection",
NewDBName: "db2",
NewName: "new_collection",
},
}
// First call Prepare to set default values
err := task.Prepare(context.Background())
assert.NoError(t, err)
// Then call Execute where encryption validation happens
err = task.Execute(context.Background())
assert.Error(t, err)
assert.Contains(t, err.Error(), "deny to change collection databases due to at least one database enabled encryption")
})
t.Run("cross database without encryption", func(t *testing.T) {
oldDB := &model.Database{
Name: "db1",
ID: 1,
}
newDB := &model.Database{
Name: "db2",
ID: 2,
}
collectionID := int64(111)
meta := mockrootcoord.NewIMetaTable(t)
meta.On("GetDatabaseByName",
mock.Anything,
"db1",
mock.AnythingOfType("uint64"),
).Return(oldDB, nil)
meta.On("GetDatabaseByName",
mock.Anything,
"db2",
mock.AnythingOfType("uint64"),
).Return(newDB, nil)
// Mock additional methods called in Execute after encryption check
meta.On("IsAlias",
mock.Anything,
"db1",
"old_collection",
).Return(false)
meta.On("GetCollectionByName",
mock.Anything,
"db1",
"old_collection",
mock.AnythingOfType("uint64"),
).Return(&model.Collection{
CollectionID: collectionID,
Name: "old_collection",
}, nil)
meta.On("ListAliasesByID",
mock.Anything,
collectionID,
).Return([]string{})
meta.On("IsAlias",
mock.Anything,
"db2",
"new_collection",
).Return(false)
meta.On("GetCollectionByName",
mock.Anything,
"db2",
"new_collection",
mock.AnythingOfType("uint64"),
).Return(nil, errors.New("not found"))
meta.On("RenameCollection",
mock.Anything,
"db1",
"old_collection",
"db2",
"new_collection",
mock.Anything,
).Return(nil)
core := newTestCore(withValidProxyManager(), withMeta(meta))
task := &renameCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.RenameCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_RenameCollection,
},
DbName: "db1",
OldName: "old_collection",
NewDBName: "db2",
NewName: "new_collection",
},
}
// First call Prepare to set default values
err := task.Prepare(context.Background())
assert.NoError(t, err)
// Then call Execute - should pass encryption check and proceed
err = task.Execute(context.Background())
assert.NoError(t, err)
})
}
func Test_renameCollectionTask_Execute(t *testing.T) {
t.Run("collection not found", func(t *testing.T) {
oldName := funcutil.GenRandomStr()
newName := funcutil.GenRandomStr()
meta := mockrootcoord.NewIMetaTable(t)
meta.On("IsAlias",
mock.Anything,
mock.Anything,
oldName,
).Return(false)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
oldName,
mock.AnythingOfType("uint64"),
).Return(nil, errors.New("collection not found"))
core := newTestCore(withMeta(meta))
task := &renameCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.RenameCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_RenameCollection,
},
OldName: oldName,
NewName: newName,
},
}
// Call Prepare to set default values
err := task.Prepare(context.Background())
assert.NoError(t, err)
err = task.Execute(context.Background())
assert.Error(t, err)
})
t.Run("rename step failed", func(t *testing.T) {
oldName := funcutil.GenRandomStr()
newName := funcutil.GenRandomStr()
collectionID := int64(111)
meta := mockrootcoord.NewIMetaTable(t)
meta.On("IsAlias",
mock.Anything,
mock.Anything,
oldName,
).Return(false)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
oldName,
mock.AnythingOfType("uint64"),
).Return(&model.Collection{
CollectionID: collectionID,
Name: oldName,
}, nil)
meta.On("ListAliasesByID",
mock.Anything,
collectionID,
).Return([]string{})
meta.On("IsAlias",
mock.Anything,
mock.Anything,
newName,
).Return(false)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
newName,
mock.AnythingOfType("uint64"),
).Return(nil, errors.New("not found"))
meta.On("RenameCollection",
mock.Anything,
mock.Anything,
oldName,
mock.Anything,
newName,
mock.Anything,
).Return(errors.New("failed to rename collection"))
core := newTestCore(withValidProxyManager(), withMeta(meta))
task := &renameCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.RenameCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_RenameCollection,
},
OldName: oldName,
NewName: newName,
},
}
// Call Prepare to set default values
err := task.Prepare(context.Background())
assert.NoError(t, err)
err = task.Execute(context.Background())
assert.Error(t, err)
})
t.Run("expire cache failed", func(t *testing.T) {
oldName := funcutil.GenRandomStr()
newName := funcutil.GenRandomStr()
collectionID := int64(111)
meta := mockrootcoord.NewIMetaTable(t)
meta.On("IsAlias",
mock.Anything,
mock.Anything,
oldName,
).Return(false)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
oldName,
mock.AnythingOfType("uint64"),
).Return(&model.Collection{
CollectionID: collectionID,
Name: oldName,
}, nil)
meta.On("ListAliasesByID",
mock.Anything,
collectionID,
).Return([]string{})
meta.On("IsAlias",
mock.Anything,
mock.Anything,
newName,
).Return(false)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
newName,
mock.AnythingOfType("uint64"),
).Return(nil, errors.New("not found"))
meta.On("RenameCollection",
mock.Anything,
mock.Anything,
oldName,
mock.Anything,
newName,
mock.Anything,
).Return(nil)
core := newTestCore(withInvalidProxyManager(), withMeta(meta))
task := &renameCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.RenameCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_RenameCollection,
},
OldName: oldName,
NewName: newName,
},
}
// Call Prepare to set default values
err := task.Prepare(context.Background())
assert.NoError(t, err)
err = task.Execute(context.Background())
assert.Error(t, err)
})
t.Run("rename with aliases within same database", func(t *testing.T) {
oldName := funcutil.GenRandomStr()
newName := funcutil.GenRandomStr()
collectionID := int64(111)
aliases := []string{"alias1", "alias2"}
meta := mockrootcoord.NewIMetaTable(t)
meta.On("IsAlias",
mock.Anything,
mock.Anything,
oldName,
).Return(false)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
oldName,
mock.AnythingOfType("uint64"),
).Return(&model.Collection{
CollectionID: collectionID,
Name: oldName,
}, nil)
meta.On("IsAlias",
mock.Anything,
mock.Anything,
newName,
).Return(false)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
newName,
mock.AnythingOfType("uint64"),
).Return(nil, errors.New("not found"))
meta.On("ListAliasesByID",
mock.Anything,
mock.Anything,
).Return(aliases)
meta.On("RenameCollection",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(nil)
core := newTestCore(withValidProxyManager(), withMeta(meta))
task := &renameCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.RenameCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_RenameCollection,
},
DbName: "db1",
OldName: oldName,
NewName: newName,
},
}
// Call Prepare to set default values
err := task.Prepare(context.Background())
assert.NoError(t, err)
err = task.Execute(context.Background())
assert.NoError(t, err)
})
t.Run("rename with aliases across databases should fail", func(t *testing.T) {
oldName := funcutil.GenRandomStr()
newName := funcutil.GenRandomStr()
collectionID := int64(111)
aliases := []string{"alias1", "alias2"}
oldDB := &model.Database{
Name: "db1",
ID: 1,
}
newDB := &model.Database{
Name: "db2",
ID: 2,
}
meta := mockrootcoord.NewIMetaTable(t)
// Mock for encryption check
meta.On("GetDatabaseByName",
mock.Anything,
"db1",
mock.AnythingOfType("uint64"),
).Return(oldDB, nil)
meta.On("GetDatabaseByName",
mock.Anything,
"db2",
mock.AnythingOfType("uint64"),
).Return(newDB, nil)
// Mock for collection checks
meta.On("IsAlias",
mock.Anything,
"db1",
oldName,
).Return(false)
meta.On("GetCollectionByName",
mock.Anything,
"db1",
oldName,
mock.AnythingOfType("uint64"),
).Return(&model.Collection{
CollectionID: collectionID,
Name: oldName,
}, nil)
meta.On("ListAliasesByID",
mock.Anything,
collectionID,
).Return(aliases)
core := newTestCore(withMeta(meta))
task := &renameCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.RenameCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_RenameCollection,
},
DbName: "db1",
OldName: oldName,
NewDBName: "db2",
NewName: newName,
},
}
// Call Prepare to set default values
err := task.Prepare(context.Background())
assert.NoError(t, err)
err = task.Execute(context.Background())
assert.Error(t, err)
assert.Contains(t, err.Error(), "must drop all aliases")
})
t.Run("rename using alias as old name should fail", func(t *testing.T) {
oldName := funcutil.GenRandomStr()
newName := funcutil.GenRandomStr()
meta := mockrootcoord.NewIMetaTable(t)
meta.On("IsAlias",
mock.Anything,
mock.Anything,
oldName,
).Return(true)
core := newTestCore(withMeta(meta))
task := &renameCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.RenameCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_RenameCollection,
},
OldName: oldName,
NewName: newName,
},
}
// Call Prepare to set default values
err := task.Prepare(context.Background())
assert.NoError(t, err)
err = task.Execute(context.Background())
assert.Error(t, err)
assert.Contains(t, err.Error(), "unsupported use an alias to rename collection")
})
t.Run("rename to existing alias should fail", func(t *testing.T) {
oldName := funcutil.GenRandomStr()
newName := funcutil.GenRandomStr()
collectionID := int64(111)
meta := mockrootcoord.NewIMetaTable(t)
meta.On("IsAlias",
mock.Anything,
mock.Anything,
oldName,
).Return(false)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
oldName,
mock.AnythingOfType("uint64"),
).Return(&model.Collection{
CollectionID: collectionID,
Name: oldName,
}, nil)
meta.On("ListAliasesByID",
mock.Anything,
collectionID,
).Return([]string{})
meta.On("IsAlias",
mock.Anything,
mock.Anything,
newName,
).Return(true)
core := newTestCore(withMeta(meta))
task := &renameCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.RenameCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_RenameCollection,
},
OldName: oldName,
NewName: newName,
},
}
err := task.Prepare(context.Background())
assert.NoError(t, err)
err = task.Execute(context.Background())
assert.Error(t, err)
assert.Contains(t, err.Error(), "cannot rename collection to an existing alias")
})
t.Run("rename to existing collection should fail", func(t *testing.T) {
oldName := funcutil.GenRandomStr()
newName := funcutil.GenRandomStr()
collectionID := int64(111)
existingCollID := int64(222)
meta := mockrootcoord.NewIMetaTable(t)
meta.On("IsAlias",
mock.Anything,
mock.Anything,
oldName,
).Return(false)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
oldName,
mock.AnythingOfType("uint64"),
).Return(&model.Collection{
CollectionID: collectionID,
Name: oldName,
}, nil)
meta.On("ListAliasesByID",
mock.Anything,
collectionID,
).Return([]string{})
meta.On("IsAlias",
mock.Anything,
mock.Anything,
newName,
).Return(false)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
newName,
mock.AnythingOfType("uint64"),
).Return(&model.Collection{
CollectionID: existingCollID,
Name: newName,
}, nil)
core := newTestCore(withMeta(meta))
task := &renameCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.RenameCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_RenameCollection,
},
OldName: oldName,
NewName: newName,
},
}
err := task.Prepare(context.Background())
assert.NoError(t, err)
err = task.Execute(context.Background())
assert.Error(t, err)
assert.Contains(t, err.Error(), "duplicated new collection name")
})
t.Run("rename across databases without aliases", func(t *testing.T) {
oldName := funcutil.GenRandomStr()
newName := funcutil.GenRandomStr()
oldDB := "db1"
newDB := "db2"
collectionID := int64(111)
meta := mockrootcoord.NewIMetaTable(t)
// Mock for encryption check
meta.On("GetDatabaseByName",
mock.Anything,
oldDB,
mock.AnythingOfType("uint64"),
).Return(&model.Database{
ID: 1,
Name: oldDB,
Properties: []*commonpb.KeyValuePair{},
}, nil)
meta.On("GetDatabaseByName",
mock.Anything,
newDB,
mock.AnythingOfType("uint64"),
).Return(&model.Database{
ID: 2,
Name: newDB,
Properties: []*commonpb.KeyValuePair{},
}, nil)
meta.On("IsAlias",
mock.Anything,
oldDB,
oldName,
).Return(false)
meta.On("GetCollectionByName",
mock.Anything,
oldDB,
oldName,
mock.AnythingOfType("uint64"),
).Return(&model.Collection{
CollectionID: collectionID,
Name: oldName,
}, nil)
meta.On("IsAlias",
mock.Anything,
newDB,
newName,
).Return(false)
meta.On("GetCollectionByName",
mock.Anything,
newDB,
newName,
mock.AnythingOfType("uint64"),
).Return(nil, errors.New("not found"))
meta.On("ListAliasesByID",
mock.Anything,
mock.Anything,
).Return([]string{})
meta.On("RenameCollection",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(nil)
core := newTestCore(withValidProxyManager(), withMeta(meta))
task := &renameCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.RenameCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_RenameCollection,
},
DbName: oldDB,
OldName: oldName,
NewDBName: newDB,
NewName: newName,
},
}
err := task.Execute(context.Background())
assert.NoError(t, err)
})
t.Run("normal case", func(t *testing.T) {
oldName := funcutil.GenRandomStr()
newName := funcutil.GenRandomStr()
collectionID := int64(111)
meta := mockrootcoord.NewIMetaTable(t)
meta.On("IsAlias",
mock.Anything,
mock.Anything,
oldName,
).Return(false)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
oldName,
mock.AnythingOfType("uint64"),
).Return(&model.Collection{
CollectionID: collectionID,
Name: oldName,
}, nil)
meta.On("IsAlias",
mock.Anything,
mock.Anything,
newName,
).Return(false)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
newName,
mock.AnythingOfType("uint64"),
).Return(nil, errors.New("not found"))
meta.On("ListAliasesByID",
mock.Anything,
mock.Anything,
).Return([]string{})
meta.On("RenameCollection",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(nil)
core := newTestCore(withValidProxyManager(), withMeta(meta))
task := &renameCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.RenameCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_RenameCollection,
},
OldName: oldName,
NewName: newName,
},
}
err := task.Prepare(context.Background())
assert.NoError(t, err)
err = task.Execute(context.Background())
assert.NoError(t, err)
})
}
func Test_renameCollectionTask_GetLockerKey(t *testing.T) {
oldName := funcutil.GenRandomStr()
newName := funcutil.GenRandomStr()
core := newTestCore()
task := &renameCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.RenameCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_RenameCollection,
},
DbName: "db1",
OldName: oldName,
NewName: newName,
},
}
key := task.GetLockerKey()
assert.NotNil(t, key)
}

View File

@ -100,7 +100,6 @@ type Core struct {
scheduler IScheduler
broker Broker
ddlTsLockManager DdlTsLockManager
stepExecutor StepExecutor
metaKVCreator metaKVCreator
@ -441,7 +440,6 @@ func (c *Core) initInternal() error {
c.broker = newServerBroker(c)
c.ddlTsLockManager = newDdlTsLockManager(c.tsoAllocator)
c.stepExecutor = newBgStepExecutor(c.ctx)
c.proxyWatcher = proxyutil.NewProxyWatcher(
c.etcdCli,
@ -648,7 +646,6 @@ func (c *Core) startInternal() error {
}
c.scheduler.Start()
c.stepExecutor.Start()
go func() {
// refresh rbac cache
if err := retry.Do(c.ctx, func() error {
@ -688,13 +685,6 @@ func (c *Core) Start() error {
return err
}
func (c *Core) stopExecutor() {
if c.stepExecutor != nil {
c.stepExecutor.Stop()
log.Ctx(c.ctx).Info("stop rootcoord executor")
}
}
func (c *Core) stopScheduler() {
if c.scheduler != nil {
c.scheduler.Stop()
@ -726,7 +716,6 @@ func (c *Core) Stop() error {
if c.tombstoneSweeper != nil {
c.tombstoneSweeper.Close()
}
c.stopExecutor()
c.stopScheduler()
if c.proxyWatcher != nil {
@ -930,45 +919,20 @@ func (c *Core) AddCollectionField(ctx context.Context, in *milvuspb.AddCollectio
metrics.RootCoordDDLReqCounter.WithLabelValues("AddCollectionField", metrics.TotalLabel).Inc()
tr := timerecord.NewTimeRecorder("AddCollectionField")
log.Ctx(ctx).Info("received request to add field",
log := log.Ctx(ctx).With(zap.String("role", typeutil.RootCoordRole),
zap.String("dbName", in.GetDbName()),
zap.String("name", in.GetCollectionName()),
zap.String("role", typeutil.RootCoordRole))
t := &addCollectionFieldTask{
baseTask: newBaseTask(ctx, c),
Req: in,
}
if err := c.scheduler.AddTask(t); err != nil {
log.Ctx(ctx).Info("failed to enqueue request to add field",
zap.String("role", typeutil.RootCoordRole),
zap.Error(err),
zap.String("name", in.GetCollectionName()))
metrics.RootCoordDDLReqCounter.WithLabelValues("AddCollectionField", metrics.FailLabel).Inc()
return merr.Status(err), nil
}
if err := t.WaitToFinish(); err != nil {
log.Ctx(ctx).Info("failed to add field",
zap.String("role", typeutil.RootCoordRole),
zap.Error(err),
zap.String("name", in.GetCollectionName()),
zap.Uint64("ts", t.GetTs()))
zap.String("collectionName", in.GetCollectionName()))
log.Info("received request to add collection field")
if err := c.broadcastAlterCollectionForAddField(ctx, in); err != nil {
log.Info("failed to add collection field", zap.Error(err))
metrics.RootCoordDDLReqCounter.WithLabelValues("AddCollectionField", metrics.FailLabel).Inc()
return merr.Status(err), nil
}
metrics.RootCoordDDLReqCounter.WithLabelValues("AddCollectionField", metrics.SuccessLabel).Inc()
metrics.RootCoordDDLReqLatency.WithLabelValues("AddCollectionField").Observe(float64(tr.ElapseSpan().Milliseconds()))
metrics.RootCoordDDLReqLatencyInQueue.WithLabelValues("AddCollectionField").Observe(float64(t.queueDur.Milliseconds()))
log.Ctx(ctx).Info("done to add field",
zap.String("role", typeutil.RootCoordRole),
zap.String("name", in.GetCollectionName()),
zap.Uint64("ts", t.GetTs()))
log.Info("done to add collection field")
return merr.Success(), nil
}
@ -1292,58 +1256,28 @@ func (c *Core) AlterCollection(ctx context.Context, in *milvuspb.AlterCollection
metrics.RootCoordDDLReqCounter.WithLabelValues("AlterCollection", metrics.TotalLabel).Inc()
tr := timerecord.NewTimeRecorder("AlterCollection")
log := log.Ctx(ctx).With(
zap.String("role", typeutil.RootCoordRole),
zap.String("name", in.GetCollectionName()),
)
log.Info("received request to alter collection",
log := log.Ctx(ctx).With(zap.String("role", typeutil.RootCoordRole),
zap.String("dbName", in.GetDbName()),
zap.String("collectionName", in.GetCollectionName()),
zap.Any("props", in.Properties),
zap.Any("delete_keys", in.DeleteKeys),
zap.Strings("deleteKeys", in.DeleteKeys),
)
log.Info("received request to alter collection")
var t task
if ok, value, err := common.IsEnableDynamicSchema(in.GetProperties()); ok {
if err != nil {
log.Warn("failed to check dynamic schema prop kv", zap.Error(err))
return merr.Status(err), nil
if err := c.broadcastAlterCollectionForAlterCollection(ctx, in); err != nil {
if errors.Is(err, errIgnoredAlterCollection) {
log.Info("alter collection make no changes, ignore it")
metrics.RootCoordDDLReqCounter.WithLabelValues("AlterCollection", metrics.SuccessLabel).Inc()
return merr.Success(), nil
}
log.Info("found update dynamic schema prop kv")
t = &alterDynamicFieldTask{
baseTask: newBaseTask(ctx, c),
Req: in,
targetValue: value,
}
} else {
t = &alterCollectionTask{
baseTask: newBaseTask(ctx, c),
Req: in,
}
}
if err := c.scheduler.AddTask(t); err != nil {
log.Warn("failed to enqueue request to alter collection",
zap.Error(err))
metrics.RootCoordDDLReqCounter.WithLabelValues("AlterCollection", metrics.FailLabel).Inc()
return merr.Status(err), nil
}
if err := t.WaitToFinish(); err != nil {
log.Warn("failed to alter collection",
zap.Error(err),
zap.Uint64("ts", t.GetTs()))
log.Warn("failed to alter collection", zap.Error(err))
metrics.RootCoordDDLReqCounter.WithLabelValues("AlterCollection", metrics.FailLabel).Inc()
return merr.Status(err), nil
}
metrics.RootCoordDDLReqCounter.WithLabelValues("AlterCollection", metrics.SuccessLabel).Inc()
metrics.RootCoordDDLReqLatency.WithLabelValues("AlterCollection").Observe(float64(tr.ElapseSpan().Milliseconds()))
metrics.RootCoordDDLReqLatencyInQueue.WithLabelValues("AlterCollection").Observe(float64(t.GetDurationInQueue().Milliseconds()))
log.Info("done to alter collection",
zap.Uint64("ts", t.GetTs()))
log.Info("done to alter collection")
return merr.Success(), nil
}
@ -1355,47 +1289,29 @@ func (c *Core) AlterCollectionField(ctx context.Context, in *milvuspb.AlterColle
metrics.RootCoordDDLReqCounter.WithLabelValues("AlterCollectionField", metrics.TotalLabel).Inc()
tr := timerecord.NewTimeRecorder("AlterCollectionField")
log.Ctx(ctx).Info("received request to alter collection field",
zap.String("role", typeutil.RootCoordRole),
zap.String("name", in.GetCollectionName()),
log := log.Ctx(ctx).With(zap.String("role", typeutil.RootCoordRole),
zap.String("dbName", in.GetDbName()),
zap.String("collectionName", in.GetCollectionName()),
zap.String("fieldName", in.GetFieldName()),
zap.Any("props", in.Properties),
zap.Strings("deleteKeys", in.DeleteKeys),
)
log.Info("received request to alter collection field")
t := &alterCollectionFieldTask{
baseTask: newBaseTask(ctx, c),
Req: in,
}
if err := c.scheduler.AddTask(t); err != nil {
log.Warn("failed to enqueue request to alter collection field",
zap.String("role", typeutil.RootCoordRole),
zap.Error(err),
zap.String("name", in.GetCollectionName()),
zap.String("fieldName", in.GetFieldName()))
metrics.RootCoordDDLReqCounter.WithLabelValues("AlterCollectionField", metrics.FailLabel).Inc()
return merr.Status(err), nil
}
if err := t.WaitToFinish(); err != nil {
log.Warn("failed to alter collection",
zap.String("role", typeutil.RootCoordRole),
zap.Error(err),
zap.String("name", in.GetCollectionName()),
zap.Uint64("ts", t.GetTs()))
if err := c.broadcastAlterCollectionV2ForAlterCollectionField(ctx, in); err != nil {
if errors.Is(err, errIgnoredAlterCollection) {
log.Info("alter collection field make no changes, ignore it")
metrics.RootCoordDDLReqCounter.WithLabelValues("AlterCollectionField", metrics.SuccessLabel).Inc()
return merr.Success(), nil
}
log.Warn("failed to alter collection field", zap.Error(err))
metrics.RootCoordDDLReqCounter.WithLabelValues("AlterCollectionField", metrics.FailLabel).Inc()
return merr.Status(err), nil
}
metrics.RootCoordDDLReqCounter.WithLabelValues("AlterCollectionField", metrics.SuccessLabel).Inc()
metrics.RootCoordDDLReqLatency.WithLabelValues("AlterCollectionField").Observe(float64(tr.ElapseSpan().Milliseconds()))
log.Info("done to alter collection field",
zap.String("role", typeutil.RootCoordRole),
zap.String("name", in.GetCollectionName()),
zap.String("fieldName", in.GetFieldName()))
log.Info("done to alter collection field")
return merr.Success(), nil
}
@ -1405,32 +1321,29 @@ func (c *Core) AlterDatabase(ctx context.Context, in *rootcoordpb.AlterDatabaseR
}
method := "AlterDatabase"
metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.TotalLabel).Inc()
tr := timerecord.NewTimeRecorder(method)
log.Ctx(ctx).Info("received request to alter database",
zap.String("role", typeutil.RootCoordRole),
zap.String("name", in.GetDbName()),
zap.Any("props", in.Properties))
log := log.Ctx(ctx).With(zap.String("role", typeutil.RootCoordRole),
zap.String("dbName", in.GetDbName()),
zap.Any("props", in.Properties),
zap.Strings("deleteKeys", in.DeleteKeys))
log.Info("received request to alter database")
if err := c.broadcastAlterDatabase(ctx, in); err != nil {
log.Warn("failed to alter database",
zap.String("role", typeutil.RootCoordRole),
zap.Error(err),
zap.String("name", in.GetDbName()))
if errors.Is(err, errIgnoredAlterDatabase) {
log.Info("alter database make no changes, ignore it")
metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.SuccessLabel).Inc()
return merr.Success(), nil
}
log.Warn("failed to alter database", zap.Error(err))
metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.FailLabel).Inc()
return merr.Status(err), nil
}
metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.SuccessLabel).Inc()
metrics.RootCoordDDLReqLatency.WithLabelValues(method).Observe(float64(tr.ElapseSpan().Milliseconds()))
// metrics.RootCoordDDLReqLatencyInQueue.WithLabelValues(method).Observe(float64(t.queueDur.Milliseconds()))
log.Ctx(ctx).Info("done to alter database",
zap.String("role", typeutil.RootCoordRole),
zap.String("name", in.GetDbName()))
log.Info("done to alter database")
return merr.Success(), nil
}
@ -2668,35 +2581,29 @@ func (c *Core) RenameCollection(ctx context.Context, req *milvuspb.RenameCollect
return merr.Status(err), nil
}
log := log.Ctx(ctx).With(zap.String("oldCollectionName", req.GetOldName()),
zap.String("newCollectionName", req.GetNewName()),
log := log.Ctx(ctx).With(zap.String("role", typeutil.RootCoordRole),
zap.String("oldDbName", req.GetDbName()),
zap.String("newDbName", req.GetNewDBName()))
zap.String("newDbName", req.GetNewDBName()),
zap.String("oldCollectionName", req.GetOldName()),
zap.String("newCollectionName", req.GetNewName()))
log.Info("received request to rename collection")
metrics.RootCoordDDLReqCounter.WithLabelValues("RenameCollection", metrics.TotalLabel).Inc()
tr := timerecord.NewTimeRecorder("RenameCollection")
t := &renameCollectionTask{
baseTask: newBaseTask(ctx, c),
Req: req,
}
if err := c.scheduler.AddTask(t); err != nil {
log.Warn("failed to enqueue request to rename collection", zap.Error(err))
metrics.RootCoordDDLReqCounter.WithLabelValues("RenameCollection", metrics.FailLabel).Inc()
return merr.Status(err), nil
}
if err := t.WaitToFinish(); err != nil {
log.Warn("failed to rename collection", zap.Uint64("ts", t.GetTs()), zap.Error(err))
if err := c.broadcastAlterCollectionForRenameCollection(ctx, req); err != nil {
if errors.Is(err, errIgnoredAlterCollection) {
log.Info("rename collection ignored, collection already uses the new name")
return merr.Success(), nil
}
log.Warn("failed to rename collection", zap.Error(err))
metrics.RootCoordDDLReqCounter.WithLabelValues("RenameCollection", metrics.FailLabel).Inc()
return merr.Status(err), nil
}
metrics.RootCoordDDLReqCounter.WithLabelValues("RenameCollection", metrics.SuccessLabel).Inc()
metrics.RootCoordDDLReqLatency.WithLabelValues("RenameCollection").Observe(float64(tr.ElapseSpan().Milliseconds()))
log.Info("done to rename collection", zap.Uint64("ts", t.GetTs()))
log.Info("done to rename collection")
return merr.Success(), nil
}

View File

@ -48,7 +48,7 @@ import (
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/etcdpb"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/proxypb"
@ -121,6 +121,7 @@ func initStreamingSystemAndCore(t *testing.T) *Core {
}
}
retry.Do(context.Background(), func() error {
log.Info("broadcast message", log.FieldMessage(msg))
return registry.CallMessageAckCallback(context.Background(), msg, results)
}, retry.AttemptAlways())
return &types.BroadcastAppendResult{}, nil
@ -131,6 +132,7 @@ func initStreamingSystemAndCore(t *testing.T) *Core {
mb.EXPECT().WithResourceKeys(mock.Anything, mock.Anything).Return(bapi, nil).Maybe()
mb.EXPECT().WithResourceKeys(mock.Anything, mock.Anything, mock.Anything).Return(bapi, nil).Maybe()
mb.EXPECT().WithResourceKeys(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(bapi, nil).Maybe()
mb.EXPECT().WithResourceKeys(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(bapi, nil).Maybe()
mb.EXPECT().Close().Return().Maybe()
broadcast.ResetBroadcaster()
broadcast.Register(mb)
@ -144,6 +146,7 @@ func initStreamingSystemAndCore(t *testing.T) *Core {
}
return vchannels, nil
}).Maybe()
b.EXPECT().WaitUntilWALbasedDDLReady(mock.Anything).Return(nil).Maybe()
b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, callback balancer.WatchChannelAssignmentsCallback) error {
<-ctx.Done()
return ctx.Err()
@ -893,36 +896,6 @@ func TestRootCoord_RenameCollection(t *testing.T) {
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("add task failed", func(t *testing.T) {
c := newTestCore(withHealthyCode(),
withInvalidScheduler())
ctx := context.Background()
resp, err := c.RenameCollection(ctx, &milvuspb.RenameCollectionRequest{})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("execute task failed", func(t *testing.T) {
c := newTestCore(withHealthyCode(),
withTaskFailScheduler())
ctx := context.Background()
resp, err := c.RenameCollection(ctx, &milvuspb.RenameCollectionRequest{})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("run ok", func(t *testing.T) {
c := newTestCore(withHealthyCode(),
withValidScheduler())
ctx := context.Background()
resp, err := c.RenameCollection(ctx, &milvuspb.RenameCollectionRequest{})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
}
func TestRootCoord_ListPolicy(t *testing.T) {
@ -1291,62 +1264,6 @@ func TestRootCoord_AlterCollection(t *testing.T) {
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("add task failed", func(t *testing.T) {
c := newTestCore(withHealthyCode(),
withInvalidScheduler())
ctx := context.Background()
resp, err := c.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("execute task failed", func(t *testing.T) {
c := newTestCore(withHealthyCode(),
withTaskFailScheduler())
ctx := context.Background()
resp, err := c.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("run ok", func(t *testing.T) {
c := newTestCore(withHealthyCode(),
withValidScheduler())
ctx := context.Background()
resp, err := c.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("set_dynamic_field_bad_request", func(t *testing.T) {
c := newTestCore(withHealthyCode(),
withValidScheduler())
ctx := context.Background()
resp, err := c.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
Properties: []*commonpb.KeyValuePair{
{Key: common.EnableDynamicSchemaKey, Value: "abc"},
},
})
assert.Error(t, merr.CheckRPCCall(resp, err))
})
t.Run("set_dynamic_field_ok", func(t *testing.T) {
c := newTestCore(withHealthyCode(),
withValidScheduler())
ctx := context.Background()
resp, err := c.AlterCollection(ctx, &milvuspb.AlterCollectionRequest{
Properties: []*commonpb.KeyValuePair{
{Key: common.EnableDynamicSchemaKey, Value: "true"},
},
})
assert.NoError(t, merr.CheckRPCCall(resp, err))
})
}
func TestRootCoord_CheckHealth(t *testing.T) {

View File

@ -1,369 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"fmt"
"time"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/util/proxyutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)
type stepPriority int
const (
stepPriorityLow = 0
stepPriorityNormal = 1
stepPriorityImportant = 10
stepPriorityUrgent = 1000
)
type nestedStep interface {
Execute(ctx context.Context) ([]nestedStep, error)
Desc() string
Weight() stepPriority
}
type baseStep struct {
core *Core
}
func (s baseStep) Desc() string {
return ""
}
func (s baseStep) Weight() stepPriority {
return stepPriorityLow
}
type cleanupMetricsStep struct {
baseStep
dbName string
collectionName string
}
func (s *cleanupMetricsStep) Execute(ctx context.Context) ([]nestedStep, error) {
metrics.CleanupRootCoordCollectionMetrics(s.dbName, s.collectionName)
return nil, nil
}
func (s *cleanupMetricsStep) Desc() string {
return fmt.Sprintf("change collection state, db: %s, collectionstate: %s",
s.dbName, s.collectionName)
}
type expireCacheStep struct {
baseStep
dbName string
collectionNames []string
collectionID UniqueID
partitionName string
ts Timestamp
opts []proxyutil.ExpireCacheOpt
}
func (s *expireCacheStep) Execute(ctx context.Context) ([]nestedStep, error) {
err := s.core.ExpireMetaCache(ctx, s.dbName, s.collectionNames, s.collectionID, s.partitionName, s.ts, s.opts...)
return nil, err
}
func (s *expireCacheStep) Desc() string {
return fmt.Sprintf("expire cache, collection id: %d, collection names: %s, ts: %d",
s.collectionID, s.collectionNames, s.ts)
}
type releaseCollectionStep struct {
baseStep
collectionID UniqueID
}
func (s *releaseCollectionStep) Execute(ctx context.Context) ([]nestedStep, error) {
err := s.core.broker.ReleaseCollection(ctx, s.collectionID)
log.Ctx(ctx).Info("release collection done", zap.Int64("collectionID", s.collectionID))
return nil, err
}
func (s *releaseCollectionStep) Desc() string {
return fmt.Sprintf("release collection: %d", s.collectionID)
}
func (s *releaseCollectionStep) Weight() stepPriority {
return stepPriorityUrgent
}
type releasePartitionsStep struct {
baseStep
collectionID UniqueID
partitionIDs []UniqueID
}
func (s *releasePartitionsStep) Execute(ctx context.Context) ([]nestedStep, error) {
err := s.core.broker.ReleasePartitions(ctx, s.collectionID, s.partitionIDs...)
return nil, err
}
func (s *releasePartitionsStep) Desc() string {
return fmt.Sprintf("release partitions, collectionID=%d, partitionIDs=%v", s.collectionID, s.partitionIDs)
}
func (s *releasePartitionsStep) Weight() stepPriority {
return stepPriorityUrgent
}
type dropIndexStep struct {
baseStep
collID UniqueID
partIDs []UniqueID
}
func (s *dropIndexStep) Execute(ctx context.Context) ([]nestedStep, error) {
err := s.core.broker.DropCollectionIndex(ctx, s.collID, s.partIDs)
return nil, err
}
func (s *dropIndexStep) Desc() string {
return fmt.Sprintf("drop collection index: %d", s.collID)
}
func (s *dropIndexStep) Weight() stepPriority {
return stepPriorityNormal
}
type nullStep struct{}
func (s *nullStep) Execute(ctx context.Context) ([]nestedStep, error) {
return nil, nil
}
func (s *nullStep) Desc() string {
return ""
}
func (s *nullStep) Weight() stepPriority {
return stepPriorityLow
}
type AlterCollectionStep struct {
baseStep
oldColl *model.Collection
newColl *model.Collection
ts Timestamp
fieldModify bool
}
func (a *AlterCollectionStep) Execute(ctx context.Context) ([]nestedStep, error) {
err := a.core.meta.AlterCollection(ctx, a.oldColl, a.newColl, a.ts, a.fieldModify)
return nil, err
}
func (a *AlterCollectionStep) Desc() string {
return fmt.Sprintf("alter collection, collectionID: %d, ts: %d", a.oldColl.CollectionID, a.ts)
}
type BroadcastAlteredCollectionStep struct {
baseStep
req *milvuspb.AlterCollectionRequest
core *Core
}
func (b *BroadcastAlteredCollectionStep) Execute(ctx context.Context) ([]nestedStep, error) {
// TODO: support online schema change mechanism
// It only broadcast collection properties to DataCoord service
err := b.core.broker.BroadcastAlteredCollection(ctx, b.req)
return nil, err
}
func (b *BroadcastAlteredCollectionStep) Desc() string {
return fmt.Sprintf("broadcast altered collection, collectionID: %d", b.req.CollectionID)
}
type AddCollectionFieldStep struct {
baseStep
oldColl *model.Collection
updatedCollection *model.Collection
newField *model.Field
ts Timestamp
}
func (a *AddCollectionFieldStep) Execute(ctx context.Context) ([]nestedStep, error) {
// newColl := a.oldColl.Clone()
// newColl.Fields = append(newColl.Fields, a.newField)
err := a.core.meta.AlterCollection(ctx, a.oldColl, a.updatedCollection, a.updatedCollection.UpdateTimestamp, true)
log.Ctx(ctx).Info("add field done", zap.Int64("collectionID", a.oldColl.CollectionID), zap.Any("new field", a.newField))
return nil, err
}
func (a *AddCollectionFieldStep) Desc() string {
return fmt.Sprintf("add field, collectionID: %d, fieldID: %d, ts: %d", a.oldColl.CollectionID, a.newField.FieldID, a.ts)
}
type WriteSchemaChangeWALStep struct {
baseStep
collection *model.Collection
ts Timestamp
}
func (s *WriteSchemaChangeWALStep) Execute(ctx context.Context) ([]nestedStep, error) {
vchannels := s.collection.VirtualChannelNames
schema := &schemapb.CollectionSchema{
Name: s.collection.Name,
Description: s.collection.Description,
AutoID: s.collection.AutoID,
Fields: model.MarshalFieldModels(s.collection.Fields),
StructArrayFields: model.MarshalStructArrayFieldModels(s.collection.StructArrayFields),
Functions: model.MarshalFunctionModels(s.collection.Functions),
EnableDynamicField: s.collection.EnableDynamicField,
Properties: s.collection.Properties,
}
schemaMsg, err := message.NewSchemaChangeMessageBuilderV2().
WithBroadcast(vchannels).
WithHeader(&message.SchemaChangeMessageHeader{
CollectionId: s.collection.CollectionID,
}).
WithBody(&message.SchemaChangeMessageBody{
Schema: schema,
}).BuildBroadcast()
if err != nil {
return nil, err
}
resp, err := streaming.WAL().Broadcast().Append(ctx, schemaMsg)
if err != nil {
return nil, err
}
// use broadcast max msg timestamp as update timestamp here
s.collection.UpdateTimestamp = lo.Max(lo.Map(vchannels, func(channelName string, _ int) uint64 {
return resp.GetAppendResult(channelName).TimeTick
}))
log.Ctx(ctx).Info(
"broadcast schema change success",
zap.Uint64("broadcastID", resp.BroadcastID),
zap.Uint64("WALUpdateTimestamp", s.collection.UpdateTimestamp),
zap.Any("appendResults", resp.AppendResults),
)
return nil, nil
}
func (s *WriteSchemaChangeWALStep) Desc() string {
return fmt.Sprintf("write schema change WALcollectionID: %d, ts: %d", s.collection.CollectionID, s.ts)
}
type renameCollectionStep struct {
baseStep
dbName string
oldName string
newDBName string
newName string
ts Timestamp
}
func (s *renameCollectionStep) Execute(ctx context.Context) ([]nestedStep, error) {
err := s.core.meta.RenameCollection(ctx, s.dbName, s.oldName, s.newDBName, s.newName, s.ts)
return nil, err
}
func (s *renameCollectionStep) Desc() string {
return fmt.Sprintf("rename collection from %s.%s to %s.%s, ts: %d",
s.dbName, s.oldName, s.newDBName, s.newName, s.ts)
}
var (
confirmGCInterval = time.Minute * 20
allPartition UniqueID = common.AllPartitionsID
)
type confirmGCStep struct {
baseStep
collectionID UniqueID
partitionID UniqueID
lastScheduledTime time.Time
}
func newConfirmGCStep(core *Core, collectionID, partitionID UniqueID) *confirmGCStep {
return &confirmGCStep{
baseStep: baseStep{core: core},
collectionID: collectionID,
partitionID: partitionID,
lastScheduledTime: time.Now(),
}
}
func (b *confirmGCStep) Execute(ctx context.Context) ([]nestedStep, error) {
if time.Since(b.lastScheduledTime) < confirmGCInterval {
return nil, fmt.Errorf("wait for reschedule to confirm GC, collection: %d, partition: %d, last scheduled time: %s, now: %s",
b.collectionID, b.partitionID, b.lastScheduledTime.String(), time.Now().String())
}
finished := b.core.broker.GcConfirm(ctx, b.collectionID, b.partitionID)
if finished {
return nil, nil
}
b.lastScheduledTime = time.Now()
return nil, fmt.Errorf("GC is not finished, collection: %d, partition: %d, last scheduled time: %s, now: %s",
b.collectionID, b.partitionID, b.lastScheduledTime.String(), time.Now().String())
}
func (b *confirmGCStep) Desc() string {
return fmt.Sprintf("wait for GC finished, collection: %d, partition: %d, last scheduled time: %s, now: %s",
b.collectionID, b.partitionID, b.lastScheduledTime.String(), time.Now().String())
}
func (b *confirmGCStep) Weight() stepPriority {
return stepPriorityLow
}
type simpleStep struct {
desc string
weight stepPriority
executeFunc func(ctx context.Context) ([]nestedStep, error)
}
func NewSimpleStep(desc string, executeFunc func(ctx context.Context) ([]nestedStep, error)) nestedStep {
return &simpleStep{
desc: desc,
weight: stepPriorityNormal,
executeFunc: executeFunc,
}
}
func (s *simpleStep) Execute(ctx context.Context) ([]nestedStep, error) {
return s.executeFunc(ctx)
}
func (s *simpleStep) Desc() string {
return s.desc
}
func (s *simpleStep) Weight() stepPriority {
return s.weight
}

View File

@ -1,258 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"sort"
"sync"
"time"
"go.uber.org/zap"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/util/retry"
)
const (
defaultBgExecutingParallel = 4
defaultBgExecutingInterval = time.Second
)
type StepExecutor interface {
Start()
Stop()
AddSteps(s *stepStack)
}
type stepStack struct {
steps []nestedStep
}
func (s *stepStack) totalPriority() int {
total := 0
for _, step := range s.steps {
total += int(step.Weight())
}
return total
}
func (s *stepStack) Execute(ctx context.Context) *stepStack {
steps := s.steps
for len(steps) > 0 {
l := len(steps)
todo := steps[l-1]
childSteps, err := todo.Execute(ctx)
// TODO: maybe a interface `step.LogOnError` is better.
_, isConfirmGCStep := todo.(*confirmGCStep)
skipLog := isConfirmGCStep
if !retry.IsRecoverable(err) {
if !skipLog {
log.Ctx(ctx).Warn("failed to execute step, not able to reschedule", zap.Error(err), zap.String("step", todo.Desc()))
}
return nil
}
if err != nil {
s.steps = nil // let's can be collected.
if !skipLog {
log.Ctx(ctx).Warn("failed to execute step, wait for reschedule", zap.Error(err), zap.String("step", todo.Desc()))
}
return &stepStack{steps: steps}
}
// this step is done.
steps = steps[:l-1]
steps = append(steps, childSteps...)
}
// everything is done.
return nil
}
type selectStepPolicy func(map[*stepStack]struct{}) []*stepStack
func randomSelect(parallel int, m map[*stepStack]struct{}) []*stepStack {
if parallel <= 0 {
parallel = defaultBgExecutingParallel
}
res := make([]*stepStack, 0, parallel)
for s := range m {
if len(res) >= parallel {
break
}
res = append(res, s)
}
return res
}
func randomSelectPolicy(parallel int) selectStepPolicy {
return func(m map[*stepStack]struct{}) []*stepStack {
return randomSelect(parallel, m)
}
}
func selectByPriority(parallel int, m map[*stepStack]struct{}) []*stepStack {
h := make([]*stepStack, 0, len(m))
for k := range m {
h = append(h, k)
}
sort.Slice(h, func(i, j int) bool {
return h[i].totalPriority() > h[j].totalPriority()
})
if len(h) <= parallel {
return h
}
return h[:parallel]
}
func selectByPriorityPolicy(parallel int) selectStepPolicy {
return func(m map[*stepStack]struct{}) []*stepStack {
return selectByPriority(parallel, m)
}
}
func defaultSelectPolicy() selectStepPolicy {
return selectByPriorityPolicy(defaultBgExecutingParallel)
}
type bgOpt func(*bgStepExecutor)
func withSelectStepPolicy(policy selectStepPolicy) bgOpt {
return func(bg *bgStepExecutor) {
bg.selector = policy
}
}
func withBgInterval(interval time.Duration) bgOpt {
return func(bg *bgStepExecutor) {
bg.interval = interval
}
}
type bgStepExecutor struct {
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
bufferedSteps map[*stepStack]struct{}
selector selectStepPolicy
mu sync.Mutex
notifyChan chan struct{}
interval time.Duration
}
func newBgStepExecutor(ctx context.Context, opts ...bgOpt) *bgStepExecutor {
ctx1, cancel := context.WithCancel(ctx)
bg := &bgStepExecutor{
ctx: ctx1,
cancel: cancel,
wg: sync.WaitGroup{},
bufferedSteps: make(map[*stepStack]struct{}),
selector: defaultSelectPolicy(),
mu: sync.Mutex{},
notifyChan: make(chan struct{}, 1),
interval: defaultBgExecutingInterval,
}
for _, opt := range opts {
opt(bg)
}
return bg
}
func (bg *bgStepExecutor) Start() {
bg.wg.Add(1)
go bg.scheduleLoop()
}
func (bg *bgStepExecutor) Stop() {
bg.cancel()
bg.wg.Wait()
}
func (bg *bgStepExecutor) AddSteps(s *stepStack) {
bg.addStepsInternal(s)
bg.notify()
}
func (bg *bgStepExecutor) process(steps []*stepStack) {
wg := sync.WaitGroup{}
for i := range steps {
s := steps[i]
if s == nil {
continue
}
wg.Add(1)
go func() {
defer wg.Done()
child := s.Execute(bg.ctx)
if child != nil {
// don't notify, wait for reschedule.
bg.addStepsInternal(child)
}
}()
}
wg.Wait()
}
func (bg *bgStepExecutor) schedule() {
bg.mu.Lock()
selected := bg.selector(bg.bufferedSteps)
for _, s := range selected {
bg.unlockRemoveSteps(s)
}
bg.mu.Unlock()
bg.process(selected)
}
func (bg *bgStepExecutor) scheduleLoop() {
defer bg.wg.Done()
ticker := time.NewTicker(bg.interval)
defer ticker.Stop()
for {
select {
case <-bg.ctx.Done():
return
case <-bg.notifyChan:
bg.schedule()
case <-ticker.C:
bg.schedule()
}
}
}
func (bg *bgStepExecutor) addStepsInternal(s *stepStack) {
bg.mu.Lock()
bg.unlockAddSteps(s)
bg.mu.Unlock()
}
func (bg *bgStepExecutor) unlockAddSteps(s *stepStack) {
bg.bufferedSteps[s] = struct{}{}
}
func (bg *bgStepExecutor) unlockRemoveSteps(s *stepStack) {
delete(bg.bufferedSteps, s)
}
func (bg *bgStepExecutor) notify() {
select {
case bg.notifyChan <- struct{}{}:
default:
}
}

View File

@ -1,229 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"math/rand"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/pkg/v2/util/retry"
)
type mockChildStep struct{}
func (m *mockChildStep) Execute(ctx context.Context) ([]nestedStep, error) {
return nil, nil
}
func (m *mockChildStep) Desc() string {
return "mock child step"
}
func (m *mockChildStep) Weight() stepPriority {
return stepPriorityLow
}
func newMockChildStep() *mockChildStep {
return &mockChildStep{}
}
type mockStepWithChild struct{}
func (m *mockStepWithChild) Execute(ctx context.Context) ([]nestedStep, error) {
return []nestedStep{newMockChildStep()}, nil
}
func (m *mockStepWithChild) Desc() string {
return "mock step with child"
}
func (m *mockStepWithChild) Weight() stepPriority {
return stepPriorityLow
}
func newMockStepWithChild() *mockStepWithChild {
return &mockStepWithChild{}
}
func Test_stepStack_Execute(t *testing.T) {
t.Run("normal case", func(t *testing.T) {
steps := []nestedStep{
newMockStepWithChild(),
newMockChildStep(),
}
s := &stepStack{steps: steps}
unfinished := s.Execute(context.Background())
assert.Nil(t, unfinished)
})
t.Run("error case", func(t *testing.T) {
steps := []nestedStep{
newMockNormalStep(),
newMockNormalStep(),
newMockFailStep(),
newMockNormalStep(),
}
s := &stepStack{steps: steps}
unfinished := s.Execute(context.Background())
assert.Equal(t, 3, len(unfinished.steps))
})
t.Run("Unrecoverable", func(t *testing.T) {
failStep := newMockFailStep()
failStep.err = retry.Unrecoverable(errors.New("error mock Execute"))
steps := []nestedStep{
failStep,
}
s := &stepStack{steps: steps}
unfinished := s.Execute(context.Background())
assert.Nil(t, unfinished)
})
}
func Test_randomSelect(t *testing.T) {
s0 := &stepStack{steps: []nestedStep{}}
s1 := &stepStack{steps: []nestedStep{
newMockNormalStep(),
}}
s2 := &stepStack{steps: []nestedStep{
newMockNormalStep(),
newMockNormalStep(),
}}
s3 := &stepStack{steps: []nestedStep{
newMockNormalStep(),
newMockNormalStep(),
newMockNormalStep(),
}}
s4 := &stepStack{steps: []nestedStep{
newMockNormalStep(),
newMockNormalStep(),
newMockNormalStep(),
newMockNormalStep(),
}}
m := map[*stepStack]struct{}{
s0: {},
s1: {},
s2: {},
s3: {},
s4: {},
}
selected := randomSelect(0, m)
assert.Equal(t, defaultBgExecutingParallel, len(selected))
for _, s := range selected {
_, ok := m[s]
assert.True(t, ok)
}
selected = randomSelect(2, m)
assert.Equal(t, 2, len(selected))
for _, s := range selected {
_, ok := m[s]
assert.True(t, ok)
}
}
func Test_bgStepExecutor_scheduleLoop(t *testing.T) {
bg := newBgStepExecutor(context.Background(),
withSelectStepPolicy(randomSelectPolicy(defaultBgExecutingParallel)),
withBgInterval(time.Millisecond*10))
bg.Start()
n := 20
records := make([]int, 0, n)
steps := make([]*stepStack, 0, n)
for i := 0; i < n; i++ {
var s *stepStack
r := rand.Int() % 3
records = append(records, r)
switch r {
case 0:
s = nil
case 1:
failStep := newMockFailStep()
failStep.err = retry.Unrecoverable(errors.New("error mock Execute"))
s = &stepStack{steps: []nestedStep{
newMockNormalStep(),
failStep,
newMockNormalStep(),
}}
case 2:
s = &stepStack{steps: []nestedStep{
newMockNormalStep(),
newMockNormalStep(),
newMockNormalStep(),
}}
default:
}
steps = append(steps, s)
bg.AddSteps(s)
}
for i, r := range records {
switch r {
case 0:
assert.Nil(t, steps[i])
case 1:
<-steps[i].steps[1].(*mockFailStep).calledChan
assert.True(t, steps[i].steps[1].(*mockFailStep).called)
case 2:
default:
}
}
bg.Stop()
}
func Test_selectByPriorityPolicy(t *testing.T) {
policy := selectByPriorityPolicy(4)
t.Run("select all", func(t *testing.T) {
m := map[*stepStack]struct{}{
{steps: []nestedStep{}}: {},
{steps: []nestedStep{}}: {},
}
selected := policy(m)
assert.Equal(t, 2, len(selected))
})
t.Run("select by priority", func(t *testing.T) {
steps := []nestedStep{
&releaseCollectionStep{},
&releaseCollectionStep{},
&releaseCollectionStep{},
&releaseCollectionStep{},
&releaseCollectionStep{},
}
s1 := &stepStack{steps: steps[0:1]}
s2 := &stepStack{steps: steps[0:2]}
s3 := &stepStack{steps: steps[0:3]}
s4 := &stepStack{steps: steps[0:4]}
s5 := &stepStack{steps: steps[0:5]}
m := map[*stepStack]struct{}{
s1: {},
s2: {},
s3: {},
s4: {},
s5: {},
}
selected := policy(m)
assert.Equal(t, 4, len(selected))
for i := 1; i < len(selected); i++ {
assert.True(t, selected[i].totalPriority() <= selected[i-1].totalPriority())
}
})
}

View File

@ -1,76 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func restoreConfirmGCInterval() {
confirmGCInterval = time.Minute * 20
}
func Test_confirmGCStep_Execute(t *testing.T) {
t.Run("wait for reschedule", func(t *testing.T) {
confirmGCInterval = time.Minute * 1000
defer restoreConfirmGCInterval()
s := &confirmGCStep{lastScheduledTime: time.Now()}
_, err := s.Execute(context.TODO())
assert.Error(t, err)
})
t.Run("GC not finished", func(t *testing.T) {
broker := newMockBroker()
broker.GCConfirmFunc = func(ctx context.Context, collectionID, partitionID UniqueID) bool {
return false
}
core := newTestCore(withBroker(broker))
confirmGCInterval = time.Millisecond
defer restoreConfirmGCInterval()
s := newConfirmGCStep(core, 100, 1000)
time.Sleep(confirmGCInterval)
_, err := s.Execute(context.TODO())
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
broker := newMockBroker()
broker.GCConfirmFunc = func(ctx context.Context, collectionID, partitionID UniqueID) bool {
return true
}
core := newTestCore(withBroker(broker))
confirmGCInterval = time.Millisecond
defer restoreConfirmGCInterval()
s := newConfirmGCStep(core, 100, 1000)
time.Sleep(confirmGCInterval)
_, err := s.Execute(context.TODO())
assert.NoError(t, err)
})
}

View File

@ -72,48 +72,6 @@ func TestLockerKey(t *testing.T) {
}
func TestGetLockerKey(t *testing.T) {
t.Run("alter collection task locker key", func(t *testing.T) {
metaMock := mockrootcoord.NewIMetaTable(t)
metaMock.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything).
RunAndReturn(func(ctx context.Context, s string, s2 string, u uint64) (*model.Collection, error) {
return &model.Collection{
Name: s2,
CollectionID: 111,
}, nil
})
c := &Core{
meta: metaMock,
}
tt := &alterCollectionTask{
baseTask: baseTask{
core: c,
},
Req: &milvuspb.AlterCollectionRequest{
DbName: "foo",
CollectionName: "bar",
},
}
key := tt.GetLockerKey()
assert.Equal(t, GetLockerKeyString(key), "$-0-false|foo-1-false|111-2-true")
})
t.Run("alter collection task locker key by ID", func(t *testing.T) {
metaMock := mockrootcoord.NewIMetaTable(t)
c := &Core{
meta: metaMock,
}
tt := &alterCollectionTask{
baseTask: baseTask{
core: c,
},
Req: &milvuspb.AlterCollectionRequest{
DbName: "foo",
CollectionName: "",
CollectionID: 111,
},
}
key := tt.GetLockerKey()
assert.Equal(t, GetLockerKeyString(key), "$-0-false|foo-1-false|111-2-true")
})
t.Run("describe collection task locker key", func(t *testing.T) {
metaMock := mockrootcoord.NewIMetaTable(t)
metaMock.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything).
@ -196,17 +154,6 @@ func TestGetLockerKey(t *testing.T) {
key := tt.GetLockerKey()
assert.Equal(t, GetLockerKeyString(key), "$-0-false")
})
t.Run("rename collection task locker key", func(t *testing.T) {
tt := &renameCollectionTask{
Req: &milvuspb.RenameCollectionRequest{
DbName: "foo",
OldName: "bar",
NewName: "baz",
},
}
key := tt.GetLockerKey()
assert.Equal(t, GetLockerKeyString(key), "$-0-true")
})
t.Run("show collection task locker key", func(t *testing.T) {
tt := &showCollectionTask{
Req: &milvuspb.ShowCollectionsRequest{

View File

@ -1,63 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus/pkg/v2/log"
)
type baseUndoTask struct {
todoStep []nestedStep // steps to execute
undoStep []nestedStep // steps to undo
stepExecutor StepExecutor
}
func newBaseUndoTask(stepExecutor StepExecutor) *baseUndoTask {
return &baseUndoTask{
todoStep: make([]nestedStep, 0),
undoStep: make([]nestedStep, 0),
stepExecutor: stepExecutor,
}
}
func (b *baseUndoTask) AddStep(todoStep, undoStep nestedStep) {
b.todoStep = append(b.todoStep, todoStep)
b.undoStep = append(b.undoStep, undoStep)
}
func (b *baseUndoTask) Execute(ctx context.Context) error {
if len(b.todoStep) != len(b.undoStep) {
return errors.New("todo step and undo step length not equal")
}
for i := 0; i < len(b.todoStep); i++ {
todoStep := b.todoStep[i]
// no children step in normal case.
if _, err := todoStep.Execute(ctx); err != nil {
log.Ctx(ctx).Warn("failed to execute step, trying to undo", zap.Error(err), zap.String("desc", todoStep.Desc()))
undoSteps := b.undoStep[:i]
b.undoStep = nil // let baseUndoTask can be collected.
go b.stepExecutor.AddSteps(&stepStack{undoSteps})
return err
}
}
return nil
}

View File

@ -1,119 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func newTestUndoTask() *baseUndoTask {
stepExecutor := newMockStepExecutor()
stepExecutor.AddStepsFunc = func(s *stepStack) {
// no schedule, execute directly.
s.Execute(context.Background())
}
undoTask := newBaseUndoTask(stepExecutor)
return undoTask
}
func Test_baseUndoTask_Execute(t *testing.T) {
t.Run("should not happen", func(t *testing.T) {
undoTask := newTestUndoTask()
undoTask.todoStep = append(undoTask.todoStep, newMockNormalStep())
err := undoTask.Execute(context.Background())
assert.Error(t, err)
})
t.Run("normal case, no undo step will be called", func(t *testing.T) {
undoTask := newTestUndoTask()
n := 10
todoSteps, undoSteps := make([]nestedStep, 0, n), make([]nestedStep, 0, n)
for i := 0; i < n; i++ {
normalTodoStep := newMockNormalStep()
normalUndoStep := newMockNormalStep()
todoSteps = append(todoSteps, normalTodoStep)
undoSteps = append(undoSteps, normalUndoStep)
}
for i := 0; i < n; i++ {
undoTask.AddStep(todoSteps[i], undoSteps[i])
}
err := undoTask.Execute(context.Background())
assert.NoError(t, err)
// make sure no undo steps will be called.
for _, step := range undoSteps {
assert.False(t, step.(*mockNormalStep).called)
}
})
t.Run("partial error, undo from last finished", func(t *testing.T) {
undoTask := newTestUndoTask()
todoSteps := []nestedStep{
newMockNormalStep(),
newMockFailStep(),
newMockNormalStep(),
}
undoSteps := []nestedStep{
newMockNormalStep(),
newMockNormalStep(),
newMockNormalStep(),
}
l := len(todoSteps)
for i := 0; i < l; i++ {
undoTask.AddStep(todoSteps[i], undoSteps[i])
}
err := undoTask.Execute(context.Background())
assert.Error(t, err)
assert.True(t, todoSteps[0].(*mockNormalStep).called)
assert.True(t, todoSteps[1].(*mockFailStep).called)
assert.False(t, todoSteps[2].(*mockNormalStep).called)
<-undoSteps[0].(*mockNormalStep).calledChan
assert.True(t, undoSteps[0].(*mockNormalStep).called)
assert.False(t, undoSteps[1].(*mockNormalStep).called)
assert.False(t, undoSteps[2].(*mockNormalStep).called)
})
t.Run("partial error, undo meet error also", func(t *testing.T) {
undoTask := newTestUndoTask()
todoSteps := []nestedStep{
newMockNormalStep(),
newMockNormalStep(),
newMockFailStep(),
}
undoSteps := []nestedStep{
newMockNormalStep(),
newMockFailStep(),
newMockNormalStep(),
}
l := len(todoSteps)
for i := 0; i < l; i++ {
undoTask.AddStep(todoSteps[i], undoSteps[i])
}
err := undoTask.Execute(context.Background())
assert.Error(t, err)
assert.True(t, todoSteps[0].(*mockNormalStep).called)
assert.True(t, todoSteps[1].(*mockNormalStep).called)
assert.True(t, todoSteps[2].(*mockFailStep).called)
assert.False(t, undoSteps[0].(*mockNormalStep).called)
<-undoSteps[1].(*mockFailStep).calledChan
assert.True(t, undoSteps[1].(*mockFailStep).called)
assert.False(t, undoSteps[2].(*mockNormalStep).called)
})
}

View File

@ -44,6 +44,9 @@ type Balancer interface {
// ReplicateRole returns the replicate role of the balancer.
ReplicateRole() replicateutil.Role
// WaitUntilWALbasedDDLReady waits until the WAL based DDL is ready.
WaitUntilWALbasedDDLReady(ctx context.Context) error
// RegisterStreamingEnabledNotifier registers a notifier into the balancer.
// If the error is returned, the balancer is closed.
// Otherwise, the following rules are applied:

View File

@ -26,6 +26,7 @@ import (
const (
versionChecker260 = "<2.6.0-dev"
versionChecker265 = "<2.6.5-dev"
)
// RecoverBalancer recover the balancer working.
@ -110,6 +111,20 @@ func (b *balancerImpl) GetLatestWALLocated(ctx context.Context, pchannel string)
return b.channelMetaManager.GetLatestWALLocated(ctx, pchannel)
}
// WaitUntilWALbasedDDLReady waits until the WAL based DDL is ready.
func (b *balancerImpl) WaitUntilWALbasedDDLReady(ctx context.Context) error {
if b.channelMetaManager.IsWALBasedDDLEnabled() {
return nil
}
if err := b.channelMetaManager.WaitUntilStreamingEnabled(ctx); err != nil {
return err
}
if err := b.blockUntilRoleGreaterThanVersion(ctx, typeutil.StreamingNodeRole, versionChecker265); err != nil {
return err
}
return b.channelMetaManager.MarkWALBasedDDLEnabled(ctx)
}
// WatchChannelAssignments watches the balance result.
func (b *balancerImpl) WatchChannelAssignments(ctx context.Context, cb WatchChannelAssignmentsCallback) error {
if !b.lifetime.Add(typeutil.LifetimeStateWorking) {
@ -335,20 +350,20 @@ func (b *balancerImpl) checkIfRoleGreaterThan260(ctx context.Context, role strin
func (b *balancerImpl) blockUntilAllNodeIsGreaterThan260AtBackground(ctx context.Context) error {
expectedRoles := []string{typeutil.ProxyRole, typeutil.DataNodeRole, typeutil.QueryNodeRole}
for _, role := range expectedRoles {
if err := b.blockUntilRoleGreaterThan260AtBackground(ctx, role); err != nil {
if err := b.blockUntilRoleGreaterThanVersion(ctx, role, versionChecker260); err != nil {
return err
}
}
return b.channelMetaManager.MarkStreamingHasEnabled(ctx)
}
// blockUntilRoleGreaterThan260AtBackground block until the role is greater than 2.6.0 at background.
func (b *balancerImpl) blockUntilRoleGreaterThan260AtBackground(ctx context.Context, role string) error {
// blockUntilRoleGreaterThanVersion block until the role is greater than 2.6.0 at background.
func (b *balancerImpl) blockUntilRoleGreaterThanVersion(ctx context.Context, role string, versionChecker string) error {
doneErr := errors.New("done")
logger := b.Logger().With(zap.String("role", role))
logger.Info("start to wait that the nodes is greater than 2.6.0")
logger.Info("start to wait that the nodes is greater than version", zap.String("version", versionChecker))
// Check if there's any proxy or data node with version < 2.6.0.
rb := resolver.NewSessionBuilder(resource.Resource().ETCD(), sessionutil.GetSessionPrefixByRole(role), versionChecker260)
rb := resolver.NewSessionBuilder(resource.Resource().ETCD(), sessionutil.GetSessionPrefixByRole(role), versionChecker)
defer rb.Close()
r := rb.Resolver()
@ -360,10 +375,10 @@ func (b *balancerImpl) blockUntilRoleGreaterThan260AtBackground(ctx context.Cont
return nil
})
if err != nil && !errors.Is(err, doneErr) {
logger.Info("fail to wait that the nodes is greater than 2.6.0", zap.Error(err))
logger.Info("fail to wait that the nodes is greater than version", zap.String("version", versionChecker), zap.Error(err))
return err
}
logger.Info("all nodes is greater than 2.6.0 when watching")
logger.Info("all nodes is greater than version when watching", zap.String("version", versionChecker))
return nil
}

View File

@ -23,6 +23,11 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
const (
StreamingVersion260 = 1 // streaming version that since 2.6.0, the streaming based WAL is available.
StreamingVersion265 = 2 // streaming version that since 2.6.5, the WAL based DDL is available.
)
var ErrChannelNotExist = errors.New("channel not exist")
type (
@ -184,6 +189,26 @@ func (cm *ChannelManager) IsStreamingEnabledOnce() bool {
return cm.streamingVersion != nil
}
// WaitUntilStreamingEnabled waits until the streaming service is enabled.
func (cm *ChannelManager) WaitUntilStreamingEnabled(ctx context.Context) error {
cm.cond.L.Lock()
for cm.streamingVersion == nil {
if err := cm.cond.Wait(ctx); err != nil {
return err
}
}
cm.cond.L.Unlock()
return nil
}
// IsWALBasedDDLEnabled returns true if the WAL based DDL is enabled.
func (cm *ChannelManager) IsWALBasedDDLEnabled() bool {
cm.cond.L.Lock()
defer cm.cond.L.Unlock()
return cm.streamingVersion != nil && cm.streamingVersion.Version >= StreamingVersion265
}
// ReplicateRole returns the replicate role of the channel manager.
func (cm *ChannelManager) ReplicateRole() replicateutil.Role {
cm.cond.L.Lock()
@ -211,8 +236,12 @@ func (cm *ChannelManager) MarkStreamingHasEnabled(ctx context.Context) error {
cm.cond.L.Lock()
defer cm.cond.L.Unlock()
if cm.streamingVersion != nil {
return nil
}
cm.streamingVersion = &streamingpb.StreamingVersion{
Version: 1,
Version: StreamingVersion260,
}
if err := resource.Resource().StreamingCatalog().SaveVersion(ctx, cm.streamingVersion); err != nil {
@ -232,6 +261,24 @@ func (cm *ChannelManager) MarkStreamingHasEnabled(ctx context.Context) error {
return nil
}
func (cm *ChannelManager) MarkWALBasedDDLEnabled(ctx context.Context) error {
cm.cond.L.Lock()
defer cm.cond.L.Unlock()
if cm.streamingVersion == nil {
return errors.New("streaming service is not enabled, cannot mark WAL based DDL enabled")
}
if cm.streamingVersion.Version >= StreamingVersion265 {
return nil
}
cm.streamingVersion.Version = StreamingVersion265
if err := resource.Resource().StreamingCatalog().SaveVersion(ctx, cm.streamingVersion); err != nil {
cm.Logger().Error("failed to save streaming version", zap.Error(err))
return err
}
return nil
}
// CurrentPChannelsView returns the current view of pchannels.
func (cm *ChannelManager) CurrentPChannelsView() *PChannelView {
cm.cond.L.Lock()

View File

@ -3,6 +3,9 @@ package broadcast
import (
"context"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
@ -30,6 +33,13 @@ func StartBroadcastWithResourceKeys(ctx context.Context, resourceKeys ...message
if err != nil {
return nil, err
}
b, err := balance.GetWithContext(ctx)
if err != nil {
return nil, err
}
if err := b.WaitUntilWALbasedDDLReady(ctx); err != nil {
return nil, errors.Wrap(err, "failed to wait until WAL based DDL ready")
}
return broadcaster.WithResourceKeys(ctx, resourceKeys...)
}

View File

@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/coordinator/snmanager"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/mocks/distributed/mock_streaming"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer"
@ -37,7 +38,9 @@ func TestAssignmentService(t *testing.T) {
broadcast.ResetBroadcaster()
// Set up the balancer
snmanager.ResetStreamingNodeManager()
b := mock_balancer.NewMockBalancer(t)
b.EXPECT().WaitUntilWALbasedDDLReady(mock.Anything).Return(nil).Maybe()
b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb balancer.WatchChannelAssignmentsCallback) error {
<-ctx.Done()
return ctx.Err()

View File

@ -10,7 +10,11 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/coordinator/snmanager"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_broadcaster"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/broadcast"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
@ -22,6 +26,16 @@ import (
func TestBroadcastService(t *testing.T) {
broadcast.ResetBroadcaster()
snmanager.ResetStreamingNodeManager()
// Set up the balancer
b := mock_balancer.NewMockBalancer(t)
b.EXPECT().WaitUntilWALbasedDDLReady(mock.Anything).Return(nil).Maybe()
b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb balancer.WatchChannelAssignmentsCallback) error {
<-ctx.Done()
return ctx.Err()
})
b.EXPECT().Close().Return().Maybe()
balance.Register(b)
fb := syncutil.NewFuture[broadcaster.Broadcaster]()
mba := mock_broadcaster.NewMockBroadcastAPI(t)

View File

@ -111,3 +111,7 @@ func (impl *msgHandlerImpl) HandleManualFlush(flushMsg message.ImmutableManualFl
func (impl *msgHandlerImpl) HandleSchemaChange(ctx context.Context, msg message.ImmutableSchemaChangeMessageV2) error {
return impl.wbMgr.SealSegments(context.Background(), msg.VChannel(), msg.Header().FlushedSegmentIds)
}
func (impl *msgHandlerImpl) HandleAlterCollection(ctx context.Context, putCollectionMsg message.ImmutableAlterCollectionMessageV2) error {
return impl.wbMgr.SealSegments(context.Background(), putCollectionMsg.VChannel(), putCollectionMsg.Header().FlushedSegmentIds)
}

View File

@ -40,6 +40,7 @@ func (impl *shardInterceptor) initOpTable() {
message.MessageTypeDelete: impl.handleDeleteMessage,
message.MessageTypeManualFlush: impl.handleManualFlushMessage,
message.MessageTypeSchemaChange: impl.handleSchemaChange,
message.MessageTypeAlterCollection: impl.handleAlterCollection,
message.MessageTypeCreateSegment: impl.handleCreateSegment,
message.MessageTypeFlush: impl.handleFlushSegment,
}
@ -244,6 +245,19 @@ func (impl *shardInterceptor) handleSchemaChange(ctx context.Context, msg messag
return appendOp(ctx, msg)
}
// handleAlterCollection handles the alter collection message.
func (impl *shardInterceptor) handleAlterCollection(ctx context.Context, msg message.MutableMessage, appendOp interceptors.Append) (message.MessageID, error) {
putCollectionMsg := message.MustAsMutableAlterCollectionMessageV2(msg)
header := putCollectionMsg.Header()
segmentIDs, err := impl.shardManager.FlushAndFenceSegmentAllocUntil(header.GetCollectionId(), msg.TimeTick())
if err != nil {
return nil, status.NewUnrecoverableError(err.Error())
}
header.FlushedSegmentIds = segmentIDs
putCollectionMsg.OverwriteHeader(header)
return appendOp(ctx, msg)
}
// handleCreateSegment handles the create segment message.
func (impl *shardInterceptor) handleCreateSegment(ctx context.Context, msg message.MutableMessage, appendOp interceptors.Append) (message.MessageID, error) {
createSegmentMsg := message.MustAsMutableCreateSegmentMessageV2(msg)

View File

@ -337,6 +337,9 @@ func (r *recoveryStorageImpl) handleMessage(msg message.ImmutableMessage) {
case message.MessageTypeSchemaChange:
immutableMsg := message.MustAsImmutableSchemaChangeMessageV2(msg)
r.handleSchemaChange(immutableMsg)
case message.MessageTypeAlterCollection:
immutableMsg := message.MustAsImmutableAlterCollectionMessageV2(msg)
r.handleAlterCollection(immutableMsg)
case message.MessageTypeTimeTick:
// nothing, the time tick message make no recovery operation.
}
@ -505,6 +508,21 @@ func (r *recoveryStorageImpl) handleSchemaChange(msg message.ImmutableSchemaChan
}
}
// handlePutCollection handles the put collection message.
func (r *recoveryStorageImpl) handleAlterCollection(msg message.ImmutableAlterCollectionMessageV2) {
// when put collection happens, we need to flush all segments in the collection.
segments := make(map[int64]struct{}, len(msg.Header().FlushedSegmentIds))
for _, segmentID := range msg.Header().FlushedSegmentIds {
segments[segmentID] = struct{}{}
}
r.flushSegments(msg, segments)
// persist the schema change into recovery info.
if vchannelInfo, ok := r.vchannels[msg.VChannel()]; ok {
vchannelInfo.ObserveAlterCollection(msg)
}
}
// detectInconsistency detects the inconsistency in the recovery storage.
func (r *recoveryStorageImpl) detectInconsistency(msg message.ImmutableMessage, reason string, extra ...zap.Field) {
fields := make([]zap.Field, 0, len(extra)+2)

View File

@ -139,6 +139,26 @@ func (info *vchannelRecoveryInfo) ObserveSchemaChange(msg message.ImmutableSchem
info.dirty = true
}
// ObservePutCollection is called when a put collection message is observed.
func (info *vchannelRecoveryInfo) ObserveAlterCollection(msg message.ImmutableAlterCollectionMessageV2) {
if msg.TimeTick() < info.meta.CheckpointTimeTick {
// the txn message will share the same time tick.
// (although the flush operation is not a txn message)
// so we only filter the time tick is less than the checkpoint time tick.
// Consistent state is guaranteed by the recovery storage's mutex.
return
}
if messageutil.IsSchemaChange(msg.Header()) {
info.meta.CollectionInfo.Schemas = append(info.meta.CollectionInfo.Schemas, &streamingpb.CollectionSchemaOfVChannel{
Schema: msg.MustBody().Updates.Schema,
State: streamingpb.VChannelSchemaState_VCHANNEL_SCHEMA_STATE_NORMAL,
CheckpointTimeTick: msg.TimeTick(),
})
}
info.meta.CheckpointTimeTick = msg.TimeTick()
info.dirty = true
}
// ObserveDropCollection is called when a drop collection message is observed.
func (info *vchannelRecoveryInfo) ObserveDropCollection(msg message.ImmutableDropCollectionMessageV1) {
if msg.TimeTick() < info.meta.CheckpointTimeTick {

View File

@ -37,3 +37,15 @@ func (pairs KeyValuePairs) Equal(other KeyValuePairs) bool {
func CloneKeyValuePairs(pairs KeyValuePairs) KeyValuePairs {
return pairs.Clone()
}
// NewKeyValuePairs creates a new KeyValuePairs from a map[string]string.
func NewKeyValuePairs(kvs map[string]string) KeyValuePairs {
pairs := make(KeyValuePairs, 0, len(kvs))
for key, value := range kvs {
pairs = append(pairs, &commonpb.KeyValuePair{
Key: key,
Value: value,
})
}
return pairs
}

View File

@ -6,5 +6,5 @@ import semver "github.com/blang/semver/v4"
var Version semver.Version
func init() {
Version = semver.MustParse("2.6.4")
Version = semver.MustParse("2.6.5-dev")
}

View File

@ -54,6 +54,7 @@ func (b *builderImpl) getConsumerConfig() kafka.ConfigMap {
config := &paramtable.Get().KafkaCfg
consumerConfig := getBasicConfig(config)
consumerConfig.SetKey("allow.auto.create.topics", true)
consumerConfig.SetKey("auto.offset.reset", "earliest")
for k, v := range config.ConsumerExtraConfig.GetValue() {
consumerConfig.SetKey(k, v)
}

View File

@ -1106,7 +1106,7 @@ func TestRenameCollectionAdvanced(t *testing.T) {
// rename: old name same with new name
err := mc.RenameCollection(ctx, client.NewRenameCollectionOption(name1, name1))
common.CheckErr(t, err, false, "duplicated new collection name")
common.CheckErr(t, err, false, "collection name or database name should be different")
// rename to a existed name
err = mc.RenameCollection(ctx, client.NewRenameCollectionOption(name1, name2))

View File

@ -527,10 +527,8 @@ class TestUtilityParams(TestcaseBase):
old_collection_name = collection_w.name
self.utility_wrap.rename_collection(old_collection_name, old_collection_name,
check_task=CheckTasks.err_res,
check_items={"err_code": 65535,
"err_msg": "duplicated new collection name {} in database default"
" with other collection name or"
" alias".format(collection_w.name)})
check_items={"err_code": 1100,
"err_msg": "collection name or database name should be different"})
@pytest.mark.tags(CaseLabel.L1)
def test_rename_collection_existed_collection_alias(self):