enhance: cherry-pick patches of new DDL framework and CDC 2 (#45241)

issue: #43897, #44123
pr: #45224
also picked PRs: #45216, #45154, #45033, #45145, #45092, #45058, #45029

enhance: Close channel replicator more gracefully (#45029)

issue: https://github.com/milvus-io/milvus/issues/44123
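
What "more gracefully" means here, going by the channelReplicator diff further below: the stream client, scanner, and target client are now closed inside the replication goroutine's own defer, and StopReplication only cancels and waits for that goroutine. A minimal Go sketch of the pattern, with the real client types reduced to a Close-only interface (everything else here is simplified, not the actual implementation):

```go
package main

import (
	"context"
	"fmt"
)

// closer is a stand-in for the stream client, scanner, and target client,
// reduced to the only method the shutdown path needs.
type closer interface{ Close() }

type replicator struct {
	ctx          context.Context
	cancel       context.CancelFunc
	done         chan struct{}
	streamClient closer
	msgScanner   closer
	targetClient closer
}

// StartReplication runs the replicate loop. All clients are closed in the
// goroutine's own defer, so nothing can close them while the loop still
// uses them.
func (r *replicator) StartReplication() {
	go func() {
		defer func() {
			if r.streamClient != nil {
				r.streamClient.Close()
			}
			if r.msgScanner != nil {
				r.msgScanner.Close()
			}
			if r.targetClient != nil {
				r.targetClient.Close()
			}
			close(r.done)
		}()
		<-r.ctx.Done() // the actual replicate loop is elided
	}()
}

// StopReplication only cancels and waits; cleanup happens in the goroutine.
func (r *replicator) StopReplication() {
	r.cancel()
	<-r.done
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	r := &replicator{ctx: ctx, cancel: cancel, done: make(chan struct{})}
	r.StartReplication()
	r.StopReplication()
	fmt.Println("replicator stopped cleanly")
}
```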

enhance: Show create time for import job (#45058)

issue: https://github.com/milvus-io/milvus/issues/45056

fix: WAL state may be inconsistent after recovering from a crash (#45092)

issue: #45088, #45086

- Messages on the control channel should trigger the checkpoint update.
- LastConfirmedMessageID should be recovered as the minimum of the
checkpoint and the LastConfirmedMessageID of any uncommitted txn (see
the sketch after this list).
- Add more log info for WAL debugging.
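
A minimal sketch of the recovery rule in the second bullet, assuming message IDs can be ordered with a plain comparison (the real WAL message IDs have their own comparison methods; this is an illustration, not the actual code):

```go
package main

import "fmt"

// messageID is a simplified stand-in for a WAL message ID (assumption:
// a plain int64 offset is enough to show the rule).
type messageID int64

// recoverLastConfirmed picks the post-crash recovery point: start from the
// checkpoint's LastConfirmedMessageID, but never later than the
// LastConfirmedMessageID of any still-uncommitted txn, so their bodies can
// be re-read during recovery.
func recoverLastConfirmed(checkpoint messageID, uncommittedTxns []messageID) messageID {
	recovered := checkpoint
	for _, txnLastConfirmed := range uncommittedTxns {
		if txnLastConfirmed < recovered {
			recovered = txnLastConfirmed
		}
	}
	return recovered
}

func main() {
	// The checkpoint advanced to 120, but a txn that began at 97 has not
	// committed yet, so recovery must start no later than 97.
	fmt.Println(recoverLastConfirmed(120, []messageID{97, 130})) // 97
}
```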

fix: make the broadcaster ack not cancelable by the client (#45145)

issue: #45141

- Make the broadcaster ack not cancelable by the RPC context (see the
sketch after this list).
- Clone the assignment snapshot of the WAL balancer.
- Add the server ID to GetReplicateCheckpoint to avoid failure.
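
One common way to make an ack immune to client-side cancellation is to run it on a context detached from the RPC. The PR text does not say exactly how it is done here, so the sketch below (using Go 1.21's context.WithoutCancel) is only an illustration of the idea; ack is a hypothetical placeholder, not the real broadcaster code:

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// ack is a hypothetical placeholder for the broadcaster's ack step
// (what the real ack does is not shown here).
func ack(ctx context.Context, broadcastID int64) error {
	select {
	case <-time.After(50 * time.Millisecond): // pretend work
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// handleAckRPC shows the idea: run the ack on a context detached from the
// client RPC, so a client-side cancel can no longer abort it half-way.
func handleAckRPC(rpcCtx context.Context, broadcastID int64) error {
	detached := context.WithoutCancel(rpcCtx) // keeps values, drops cancellation
	return ack(detached, broadcastID)
}

func main() {
	rpcCtx, cancel := context.WithCancel(context.Background())
	cancel() // the client goes away immediately
	fmt.Println(handleAckRPC(rpcCtx, 1)) // still succeeds: <nil>
}
```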

enhance: support collection and index with WAL-based DDL framework
(#45033)

issue: #43897

- Part of the collection/index-related DDL is now implemented by the
WAL-based DDL framework (see the sketch after this list).
- Support the following message types in the WAL: CreateCollection,
DropCollection, CreatePartition, DropPartition, CreateIndex, AlterIndex,
and DropIndex.
- Part of the collection/index-related DDL can now be synced by the new
CDC.
- Refactor some UTs for collection/index DDL.
- Add a Tombstone scheduler to manage the tombstone GC for collection or
partition meta.
- Move the vchannel allocation into the streaming pchannel manager.
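
The overall flow, as visible in the diff below (registry.RegisterCreateIndexV2AckCallback, message.NewCreateIndexMessageBuilderV2, broadcaster.Broadcast): the DDL is broadcast as a WAL message on the control channel, and the meta mutation is applied in the registered ack callback, which is also what lets the new CDC replay the same DDL downstream. A reduced, self-contained sketch of that control flow; all types and functions here are stand-ins, not the real API:

```go
package main

import (
	"context"
	"fmt"
)

// createIndexMsg is a stand-in for the CreateIndex WAL message payload.
type createIndexMsg struct {
	CollectionID int64
	IndexID      int64
	IndexName    string
}

// ackCallbacks stands in for the callback registry: the coordinator registers
// a handler that applies the DDL to its meta once the WAL has acked the message.
var ackCallbacks []func(ctx context.Context, msg createIndexMsg) error

func registerCreateIndexAckCallback(cb func(ctx context.Context, msg createIndexMsg) error) {
	ackCallbacks = append(ackCallbacks, cb)
}

// broadcast stands in for broadcaster.Broadcast: the message goes to the WAL
// control channel first, then the ack callbacks apply it to local meta.
func broadcast(ctx context.Context, msg createIndexMsg) error {
	// ... appending the message to the WAL control channel is elided ...
	for _, cb := range ackCallbacks {
		if err := cb(ctx, msg); err != nil {
			return err
		}
	}
	return nil
}

func main() {
	registerCreateIndexAckCallback(func(ctx context.Context, msg createIndexMsg) error {
		fmt.Printf("apply CreateIndex to meta: coll=%d index=%d name=%s\n",
			msg.CollectionID, msg.IndexID, msg.IndexName)
		return nil
	})
	_ = broadcast(context.Background(), createIndexMsg{CollectionID: 1, IndexID: 100, IndexName: "_default_idx"})
}
```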

enhance: support load/release collection/partition with WAL-based DDL
framework (#45154)

issue: #43897

- Load/Release collection/partition is now implemented by the WAL-based
DDL framework.
- Support AlterLoadConfig/DropLoadConfig in the WAL now.
- Load/Release operations can now be synced by the new CDC (see the
sketch after this list).
- Refactor some UTs for load/release DDL.
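
Load/Release follows the same broadcast-then-ack pattern with the new AlterLoadConfig/DropLoadConfig messages. One detail suggested by the test changes further below: the mocked broadcaster retries the ack callback until it succeeds (retry.AttemptAlways), so callbacks apparently need to tolerate being re-run. A small sketch of an idempotent AlterLoadConfig apply under that assumption; every name here is illustrative, not the real API:

```go
package main

import (
	"context"
	"fmt"
	"sync"
)

// loadConfig is a reduced stand-in for the AlterLoadConfig payload.
type loadConfig struct {
	CollectionID  int64
	ReplicaNumber int32
}

// loadMeta mimics the coordinator-side state the ack callback mutates.
type loadMeta struct {
	mu     sync.Mutex
	loaded map[int64]loadConfig
}

// applyAlterLoadConfig is written to be idempotent: if the WAL framework
// retries the ack callback after a failure, applying the same message twice
// leaves the meta unchanged.
func (m *loadMeta) applyAlterLoadConfig(ctx context.Context, cfg loadConfig) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.loaded[cfg.CollectionID] = cfg // overwriting with the same value is a no-op
	return nil
}

func main() {
	meta := &loadMeta{loaded: map[int64]loadConfig{}}
	cfg := loadConfig{CollectionID: 1, ReplicaNumber: 2}
	_ = meta.applyAlterLoadConfig(context.Background(), cfg)
	_ = meta.applyAlterLoadConfig(context.Background(), cfg) // retried ack, same state
	fmt.Println(meta.loaded[1]) // {1 2}
}
```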

enhance: Don't start CDC by default (#45216)

issue: https://github.com/milvus-io/milvus/issues/44123


fix: unrecoverable when replicating from an old version (#45224)

issue: #44962

---------

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
Signed-off-by: chyezh <chyezh@outlook.com>
Co-authored-by: yihao.dai <yihao.dai@zilliz.com>
Zhen Ye 2025-11-04 01:35:33 +08:00 committed by GitHub
parent cefdd25ef7
commit 02e2170601
136 changed files with 5364 additions and 7507 deletions

View File

@ -6,7 +6,7 @@ require (
github.com/blang/semver/v4 v4.0.0
github.com/cockroachdb/errors v1.9.1
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/milvus-io/milvus-proto/go-api/v2 v2.6.4-0.20251013093953-f3e0a710c654
github.com/milvus-io/milvus-proto/go-api/v2 v2.6.5-0.20251103083929-99dbd46f10b7
github.com/milvus-io/milvus/pkg/v2 v2.6.3
github.com/quasilyte/go-ruleguard/dsl v0.3.23
github.com/samber/lo v1.27.0

View File

@ -322,8 +322,8 @@ github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5
github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8=
github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc=
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/milvus-io/milvus-proto/go-api/v2 v2.6.4-0.20251013093953-f3e0a710c654 h1:p604i9izeR8eWrQhOFmcmxhNhYlsvTkkmph4b2GbOeg=
github.com/milvus-io/milvus-proto/go-api/v2 v2.6.4-0.20251013093953-f3e0a710c654/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/milvus-proto/go-api/v2 v2.6.5-0.20251103083929-99dbd46f10b7 h1:AxECtO0R/G622zMHniIh11JjL/nvu84xQSXI6KQSxRs=
github.com/milvus-io/milvus-proto/go-api/v2 v2.6.5-0.20251103083929-99dbd46f10b7/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/milvus/pkg/v2 v2.6.3 h1:WDf4mXFWL5Sk/V87yLwRKq24MYMkjS2YA6qraXbLbJA=
github.com/milvus-io/milvus/pkg/v2 v2.6.3/go.mod h1:49umaGHK9nKHJNtgBlF/iB24s1sZ/SG5/Q7iLj/Gc14=
github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc=

View File

@ -1312,7 +1312,7 @@ streaming:
# it also determine the depth of depth first search method that is used to find the best balance result, 3 by default
rebalanceMaxStep: 3
walBroadcaster:
concurrencyRatio: 1 # The concurrency ratio based on number of CPU for wal broadcaster, 1 by default.
concurrencyRatio: 4 # The concurrency ratio based on number of CPU for wal broadcaster, 4 by default.
txn:
defaultKeepaliveTimeout: 10s # The default keepalive timeout for wal txn, 10s by default
walWriteAheadBuffer:

2
go.mod
View File

@ -21,7 +21,7 @@ require (
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/klauspost/compress v1.17.9
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d
github.com/milvus-io/milvus-proto/go-api/v2 v2.6.4-0.20251013093953-f3e0a710c654
github.com/milvus-io/milvus-proto/go-api/v2 v2.6.5-0.20251103083929-99dbd46f10b7
github.com/minio/minio-go/v7 v7.0.73
github.com/panjf2000/ants/v2 v2.11.3 // indirect
github.com/pingcap/log v1.1.1-0.20221015072633-39906604fb81 // indirect

4
go.sum
View File

@ -794,8 +794,8 @@ github.com/milvus-io/cgosymbolizer v0.0.0-20250318084424-114f4050c3a6 h1:YHMFI6L
github.com/milvus-io/cgosymbolizer v0.0.0-20250318084424-114f4050c3a6/go.mod h1:DvXTE/K/RtHehxU8/GtDs4vFtfw64jJ3PaCnFri8CRg=
github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b h1:TfeY0NxYxZzUfIfYe5qYDBzt4ZYRqzUjTR6CvUzjat8=
github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b/go.mod h1:iwW+9cWfIzzDseEBCCeDSN5SD16Tidvy8cwQ7ZY8Qj4=
github.com/milvus-io/milvus-proto/go-api/v2 v2.6.4-0.20251013093953-f3e0a710c654 h1:p604i9izeR8eWrQhOFmcmxhNhYlsvTkkmph4b2GbOeg=
github.com/milvus-io/milvus-proto/go-api/v2 v2.6.4-0.20251013093953-f3e0a710c654/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/milvus-io/milvus-proto/go-api/v2 v2.6.5-0.20251103083929-99dbd46f10b7 h1:AxECtO0R/G622zMHniIh11JjL/nvu84xQSXI6KQSxRs=
github.com/milvus-io/milvus-proto/go-api/v2 v2.6.5-0.20251103083929-99dbd46f10b7/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs=
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs=
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY=
github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI=

View File

@ -14,6 +14,9 @@ packages:
Broadcast:
Local:
Scanner:
github.com/milvus-io/milvus/internal/rootcoord/tombstone:
interfaces:
TombstoneSweeper:
github.com/milvus-io/milvus/internal/streamingcoord/server/balancer:
interfaces:
Balancer:

View File

@ -79,7 +79,18 @@ func (r *channelReplicator) StartReplication() {
logger := log.With(zap.String("key", r.channel.Key), zap.Int64("modRevision", r.channel.ModRevision))
logger.Info("start replicate channel")
go func() {
defer r.asyncNotifier.Finish(struct{}{})
defer func() {
if r.streamClient != nil {
r.streamClient.Close()
}
if r.msgScanner != nil {
r.msgScanner.Close()
}
if r.targetClient != nil {
r.targetClient.Close(r.asyncNotifier.Context())
}
r.asyncNotifier.Finish(struct{}{})
}()
INIT_LOOP:
for {
select {
@ -201,14 +212,5 @@ func (r *channelReplicator) getReplicateCheckpoint() (*utility.ReplicateCheckpoi
func (r *channelReplicator) StopReplication() {
r.asyncNotifier.Cancel()
r.asyncNotifier.BlockUntilFinish() // wait for the start goroutine to finish
if r.targetClient != nil {
r.targetClient.Close(r.asyncNotifier.Context())
}
if r.streamClient != nil {
r.streamClient.Close()
}
if r.msgScanner != nil {
r.msgScanner.Close()
}
r.asyncNotifier.BlockUntilFinish()
}

View File

@ -19,6 +19,7 @@ package replicatestream
import (
"context"
"strconv"
"sync"
"testing"
"time"
@ -103,6 +104,7 @@ func TestReplicateStreamClient_Replicate(t *testing.T) {
assert.Eventually(t, func() bool {
return replicateClient.(*replicateStreamClient).pendingMessages.Len() == 0
}, time.Second, 100*time.Millisecond)
mockStreamClient.Close()
}
func TestReplicateStreamClient_Replicate_ContextCanceled(t *testing.T) {
@ -218,6 +220,7 @@ func TestReplicateStreamClient_Reconnect(t *testing.T) {
assert.Eventually(t, func() bool {
return replicateClient.(*replicateStreamClient).pendingMessages.Len() == 0
}, time.Second, 100*time.Millisecond)
mockStreamClient.Close()
}
// mockReplicateStreamClient implements the milvuspb.MilvusService_CreateReplicateStreamClient interface
@ -232,7 +235,8 @@ type mockReplicateStreamClient struct {
t *testing.T
timeout time.Duration
closeCh chan struct{}
closeOnce sync.Once
closeCh chan struct{}
}
func newMockReplicateStreamClient(t *testing.T) *mockReplicateStreamClient {
@ -242,6 +246,7 @@ func newMockReplicateStreamClient(t *testing.T) *mockReplicateStreamClient {
t: t,
timeout: 10 * time.Second,
closeCh: make(chan struct{}, 1),
closeOnce: sync.Once{},
}
}
@ -311,10 +316,18 @@ func (m *mockReplicateStreamClient) Trailer() metadata.MD {
}
func (m *mockReplicateStreamClient) CloseSend() error {
close(m.closeCh)
m.closeOnce.Do(func() {
close(m.closeCh)
})
return nil
}
func (m *mockReplicateStreamClient) Context() context.Context {
return context.Background()
}
func (m *mockReplicateStreamClient) Close() {
m.closeOnce.Do(func() {
close(m.closeCh)
})
}

View File

@ -78,6 +78,15 @@ func (s *StreamingNodeManager) GetBalancer() balancer.Balancer {
return b
}
// AllocVirtualChannels allocates virtual channels for a collection.
func (s *StreamingNodeManager) AllocVirtualChannels(ctx context.Context, param balancer.AllocVChannelParam) ([]string, error) {
balancer, err := balance.GetWithContext(ctx)
if err != nil {
return nil, err
}
return balancer.AllocVirtualChannels(ctx, param)
}
// GetLatestWALLocated returns the server id of the node that the wal of the vChannel is located.
// Return -1 and error if the vchannel is not found or context is canceled.
func (s *StreamingNodeManager) GetLatestWALLocated(ctx context.Context, vchannel string) (int64, error) {

View File

@ -627,11 +627,12 @@ func (s *ClusteringCompactionTaskSuite) TestProcessIndexingState() {
task := s.generateBasicTask(false)
task.updateAndSaveTaskMeta(setState(datapb.CompactionTaskState_indexing))
indexReq := &indexpb.CreateIndexRequest{
CollectionID: 1,
}
task.updateAndSaveTaskMeta(setResultSegments([]int64{10, 11}))
_, err := s.meta.indexMeta.CreateIndex(context.TODO(), indexReq, 3, false)
err := s.meta.indexMeta.CreateIndex(context.TODO(), &model.Index{
CollectionID: 1,
FieldID: 3,
IndexID: 3,
})
s.NoError(err)
s.False(task.Process())
@ -641,10 +642,11 @@ func (s *ClusteringCompactionTaskSuite) TestProcessIndexingState() {
s.Run("collection has index, segment indexed", func() {
task := s.generateBasicTask(false)
task.updateAndSaveTaskMeta(setState(datapb.CompactionTaskState_indexing))
indexReq := &indexpb.CreateIndexRequest{
err := s.meta.indexMeta.CreateIndex(context.TODO(), &model.Index{
CollectionID: 1,
}
_, err := s.meta.indexMeta.CreateIndex(context.TODO(), indexReq, 3, false)
FieldID: 3,
IndexID: 3,
})
s.NoError(err)
s.meta.indexMeta.updateSegmentIndex(&model.SegmentIndex{

View File

@ -0,0 +1,59 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package datacoord
import (
"context"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/broadcast"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)
// RegisterDDLCallbacks registers the ddl callbacks.
func RegisterDDLCallbacks(s *Server) {
ddlCallback := &DDLCallbacks{
Server: s,
}
ddlCallback.registerIndexCallbacks()
}
type DDLCallbacks struct {
*Server
}
func (c *DDLCallbacks) registerIndexCallbacks() {
registry.RegisterCreateIndexV2AckCallback(c.createIndexV2AckCallback)
registry.RegisterAlterIndexV2AckCallback(c.alterIndexV2AckCallback)
registry.RegisterDropIndexV2AckCallback(c.dropIndexV2Callback)
}
// startBroadcastWithCollectionID starts a broadcast with collection name.
func (s *Server) startBroadcastWithCollectionID(ctx context.Context, collectionID int64) (broadcaster.BroadcastAPI, error) {
coll, err := s.broker.DescribeCollectionInternal(ctx, collectionID)
if err != nil {
return nil, err
}
dbName := coll.GetDbName()
collectionName := coll.GetCollectionName()
broadcaster, err := broadcast.StartBroadcastWithResourceKeys(ctx, message.NewSharedDBNameResourceKey(dbName), message.NewExclusiveCollectionNameResourceKey(dbName, collectionName))
if err != nil {
return nil, err
}
return broadcaster, nil
}

View File

@ -0,0 +1,35 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package datacoord
import (
"context"
"github.com/samber/lo"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)
func (s *DDLCallbacks) alterIndexV2AckCallback(ctx context.Context, result message.BroadcastResultAlterIndexMessageV2) error {
indexes := result.Message.MustBody().FieldIndexes
indexModels := lo.Map(indexes, func(index *indexpb.FieldIndex, _ int) *model.Index {
return model.UnmarshalIndexModel(index)
})
return s.meta.indexMeta.AlterIndex(ctx, indexModels...)
}

View File

@ -0,0 +1,36 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package datacoord
import (
"context"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)
func (s *DDLCallbacks) createIndexV2AckCallback(ctx context.Context, result message.BroadcastResultCreateIndexMessageV2) error {
index := result.Message.MustBody().FieldIndex
if err := s.meta.indexMeta.CreateIndex(ctx, model.UnmarshalIndexModel(index)); err != nil {
return err
}
select {
case s.notifyIndexChan <- index.IndexInfo.CollectionID:
default:
}
return nil
}

View File

@ -0,0 +1,28 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package datacoord
import (
"context"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)
func (s *DDLCallbacks) dropIndexV2Callback(ctx context.Context, result message.BroadcastResultDropIndexMessageV2) error {
header := result.Message.Header()
return s.meta.indexMeta.MarkIndexAsDeleted(ctx, header.GetCollectionId(), header.GetIndexIds())
}

View File

@ -16,7 +16,6 @@ import (
mocks2 "github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/proto/workerpb"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
@ -41,11 +40,11 @@ func TestGetQueryVChanPositionsRetrieveM2N(t *testing.T) {
},
},
})
indexReq := &indexpb.CreateIndexRequest{
err := svr.meta.indexMeta.CreateIndex(context.TODO(), &model.Index{
CollectionID: 1,
FieldID: 2,
}
_, err := svr.meta.indexMeta.CreateIndex(context.TODO(), indexReq, 1, false)
IndexID: 1,
})
require.NoError(t, err)
segArgs := []struct {
@ -153,12 +152,11 @@ func TestGetQueryVChanPositions(t *testing.T) {
},
})
indexReq := &indexpb.CreateIndexRequest{
err := svr.meta.indexMeta.CreateIndex(context.TODO(), &model.Index{
CollectionID: 0,
FieldID: 2,
}
_, err := svr.meta.indexMeta.CreateIndex(context.TODO(), indexReq, 1, false)
IndexID: 1,
})
assert.NoError(t, err)
s1 := &datapb.SegmentInfo{
@ -337,11 +335,11 @@ func TestGetQueryVChanPositions_Retrieve_unIndexed(t *testing.T) {
ID: 0,
Schema: schema,
})
indexReq := &indexpb.CreateIndexRequest{
err := svr.meta.indexMeta.CreateIndex(context.TODO(), &model.Index{
CollectionID: 0,
FieldID: 2,
}
_, err := svr.meta.indexMeta.CreateIndex(context.TODO(), indexReq, 1, false)
IndexID: 1,
})
assert.NoError(t, err)
c := &datapb.SegmentInfo{
ID: 1,
@ -406,11 +404,11 @@ func TestGetQueryVChanPositions_Retrieve_unIndexed(t *testing.T) {
ID: 0,
Schema: schema,
})
indexReq := &indexpb.CreateIndexRequest{
err := svr.meta.indexMeta.CreateIndex(context.TODO(), &model.Index{
CollectionID: 0,
FieldID: 2,
}
_, err := svr.meta.indexMeta.CreateIndex(context.TODO(), indexReq, 1, false)
IndexID: 1,
})
assert.NoError(t, err)
a := &datapb.SegmentInfo{
ID: 99,
@ -491,11 +489,11 @@ func TestGetQueryVChanPositions_Retrieve_unIndexed(t *testing.T) {
ID: 0,
Schema: schema,
})
indexReq := &indexpb.CreateIndexRequest{
err := svr.meta.indexMeta.CreateIndex(context.TODO(), &model.Index{
CollectionID: 0,
FieldID: 2,
}
_, err := svr.meta.indexMeta.CreateIndex(context.TODO(), indexReq, 1, false)
IndexID: 1,
})
assert.NoError(t, err)
c := &datapb.SegmentInfo{
ID: 1,
@ -600,11 +598,11 @@ func TestGetQueryVChanPositions_Retrieve_unIndexed(t *testing.T) {
Partitions: []int64{0},
Schema: schema,
})
indexReq := &indexpb.CreateIndexRequest{
err := svr.meta.indexMeta.CreateIndex(context.TODO(), &model.Index{
CollectionID: 0,
FieldID: 2,
}
_, err := svr.meta.indexMeta.CreateIndex(context.TODO(), indexReq, 1, false)
IndexID: 1,
})
assert.NoError(t, err)
seg1 := &datapb.SegmentInfo{
ID: 1,
@ -980,11 +978,11 @@ func TestGetQueryVChanPositions_Retrieve_unIndexed(t *testing.T) {
Partitions: []int64{0},
Schema: schema,
})
indexReq := &indexpb.CreateIndexRequest{
err := svr.meta.indexMeta.CreateIndex(context.TODO(), &model.Index{
CollectionID: 0,
FieldID: 2,
}
_, err := svr.meta.indexMeta.CreateIndex(context.TODO(), indexReq, 1, false)
IndexID: 1,
})
assert.NoError(t, err)
seg1 := &datapb.SegmentInfo{
ID: 1,
@ -1184,11 +1182,11 @@ func TestGetCurrentSegmentsView(t *testing.T) {
Schema: schema,
})
indexReq := &indexpb.CreateIndexRequest{
err := svr.meta.indexMeta.CreateIndex(context.TODO(), &model.Index{
CollectionID: 0,
FieldID: 2,
}
_, err := svr.meta.indexMeta.CreateIndex(context.TODO(), indexReq, 1, false)
IndexID: 1,
})
assert.NoError(t, err)
seg1 := &datapb.SegmentInfo{
ID: 1,

View File

@ -355,7 +355,12 @@ func checkParams(fieldIndex *model.Index, req *indexpb.CreateIndexRequest) bool
func (m *indexMeta) CanCreateIndex(req *indexpb.CreateIndexRequest, isJson bool) (UniqueID, error) {
m.fieldIndexLock.RLock()
defer m.fieldIndexLock.RUnlock()
return m.canCreateIndex(req, isJson)
indexID, err := m.canCreateIndex(req, isJson)
if err != nil {
return 0, err
}
return indexID, nil
}
func (m *indexMeta) canCreateIndex(req *indexpb.CreateIndexRequest, isJson bool) (UniqueID, error) {
@ -424,37 +429,10 @@ func (m *indexMeta) HasSameReq(req *indexpb.CreateIndexRequest) (bool, UniqueID)
return false, 0
}
func (m *indexMeta) CreateIndex(ctx context.Context, req *indexpb.CreateIndexRequest, allocatedIndexID typeutil.UniqueID, isJson bool) (UniqueID, error) {
func (m *indexMeta) CreateIndex(ctx context.Context, index *model.Index) error {
m.fieldIndexLock.Lock()
defer m.fieldIndexLock.Unlock()
indexID, err := m.canCreateIndex(req, isJson)
if err != nil {
return indexID, err
}
if indexID == 0 {
indexID = allocatedIndexID
} else {
return indexID, nil
}
// exclude the mmap.enable param, because it will be conflicted with the index's mmap.enable param
typeParams := DeleteParams(req.GetTypeParams(), []string{common.MmapEnabledKey})
index := &model.Index{
CollectionID: req.GetCollectionID(),
FieldID: req.GetFieldID(),
IndexID: indexID,
IndexName: req.GetIndexName(),
TypeParams: typeParams,
IndexParams: req.GetIndexParams(),
CreateTime: req.GetTimestamp(),
IsAutoIndex: req.GetIsAutoIndex(),
UserIndexParams: req.GetUserIndexParams(),
}
if err := ValidateIndexParams(index); err != nil {
return indexID, err
}
log.Ctx(ctx).Info("meta update: CreateIndex", zap.Int64("collectionID", index.CollectionID),
zap.Int64("fieldID", index.FieldID), zap.Int64("indexID", index.IndexID), zap.String("indexName", index.IndexName))
@ -462,13 +440,13 @@ func (m *indexMeta) CreateIndex(ctx context.Context, req *indexpb.CreateIndexReq
log.Ctx(ctx).Error("meta update: CreateIndex save meta fail", zap.Int64("collectionID", index.CollectionID),
zap.Int64("fieldID", index.FieldID), zap.Int64("indexID", index.IndexID),
zap.String("indexName", index.IndexName), zap.Error(err))
return indexID, err
return err
}
m.updateCollectionIndex(index)
log.Ctx(ctx).Info("meta update: CreateIndex success", zap.Int64("collectionID", index.CollectionID),
zap.Int64("fieldID", index.FieldID), zap.Int64("indexID", index.IndexID), zap.String("indexName", index.IndexName))
return indexID, nil
return nil
}
func (m *indexMeta) AlterIndex(ctx context.Context, indexes ...*model.Index) error {
@ -661,6 +639,17 @@ func (m *indexMeta) MarkIndexAsDeleted(ctx context.Context, collID UniqueID, ind
m.fieldIndexLock.Lock()
defer m.fieldIndexLock.Unlock()
if len(indexIDs) == 0 {
// drop all indexes if indexIDs is empty.
indexIDs = make([]UniqueID, 0, len(m.indexes[collID]))
for indexID, index := range m.indexes[collID] {
if index.IsDeleted {
continue
}
indexIDs = append(indexIDs, indexID)
}
}
fieldIndexes, ok := m.indexes[collID]
if !ok {
return nil

View File

@ -265,13 +265,25 @@ func TestMeta_CanCreateIndex(t *testing.T) {
IsAutoIndex: false,
UserIndexParams: userIndexParams,
}
indexModel := &model.Index{
CollectionID: req.CollectionID,
FieldID: req.FieldID,
IndexID: indexID,
IndexName: req.IndexName,
IsDeleted: false,
CreateTime: req.Timestamp,
TypeParams: req.TypeParams,
IndexParams: req.IndexParams,
IsAutoIndex: req.IsAutoIndex,
UserIndexParams: userIndexParams,
}
t.Run("can create index", func(t *testing.T) {
tmpIndexID, err := m.CanCreateIndex(req, false)
assert.NoError(t, err)
assert.Equal(t, int64(0), tmpIndexID)
indexID, err = m.CreateIndex(context.TODO(), req, indexID, false)
err = m.CreateIndex(context.TODO(), indexModel)
assert.NoError(t, err)
tmpIndexID, err = m.CanCreateIndex(req, false)
@ -465,8 +477,18 @@ func TestMeta_CreateIndex(t *testing.T) {
IsAutoIndex: false,
UserIndexParams: indexParams,
}
allocatedID := UniqueID(3)
indexModel := &model.Index{
CollectionID: req.CollectionID,
FieldID: req.FieldID,
IndexID: allocatedID,
IndexName: req.IndexName,
TypeParams: req.TypeParams,
IndexParams: req.IndexParams,
CreateTime: req.Timestamp,
IsAutoIndex: req.IsAutoIndex,
UserIndexParams: req.UserIndexParams,
}
t.Run("success", func(t *testing.T) {
sc := catalogmocks.NewDataCoordCatalog(t)
@ -476,7 +498,7 @@ func TestMeta_CreateIndex(t *testing.T) {
).Return(nil)
m := newSegmentIndexMeta(sc)
_, err := m.CreateIndex(context.TODO(), req, allocatedID, false)
err := m.CreateIndex(context.TODO(), indexModel)
assert.NoError(t, err)
})
@ -488,7 +510,8 @@ func TestMeta_CreateIndex(t *testing.T) {
).Return(errors.New("fail"))
m := newSegmentIndexMeta(ec)
_, err := m.CreateIndex(context.TODO(), req, 4, false)
indexModel.IndexID = 4
err := m.CreateIndex(context.TODO(), indexModel)
assert.Error(t, err)
})
}

View File

@ -26,6 +26,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
@ -37,6 +38,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metautil"
@ -71,7 +73,7 @@ func (s *Server) getFieldNameByID(schema *schemapb.CollectionSchema, fieldID int
func (s *Server) getSchema(ctx context.Context, collID int64) (*schemapb.CollectionSchema, error) {
resp, err := s.broker.DescribeCollectionInternal(ctx, collID)
if err != nil {
if err := merr.CheckRPCCall(resp.Status, err); err != nil {
return nil, err
}
return resp.GetSchema(), nil
@ -140,10 +142,18 @@ func (s *Server) CreateIndex(ctx context.Context, req *indexpb.CreateIndexReques
}
metrics.IndexRequestCounter.WithLabelValues(metrics.TotalLabel).Inc()
schema, err := s.getSchema(ctx, req.GetCollectionID())
// Create a new broadcaster for the collection.
broadcaster, err := s.startBroadcastWithCollectionID(ctx, req.GetCollectionID())
if err != nil {
return merr.Status(err), nil
}
defer broadcaster.Close()
coll, err := s.broker.DescribeCollectionInternal(ctx, req.GetCollectionID())
if err := merr.CheckRPCCall(coll.Status, err); err != nil {
return merr.Status(err), nil
}
schema := coll.GetSchema()
if !FieldExists(schema, req.GetFieldID()) {
return merr.Status(merr.WrapErrFieldNotFound(req.GetFieldID())), nil
@ -199,30 +209,59 @@ func (s *Server) CreateIndex(ctx context.Context, req *indexpb.CreateIndexReques
}
}
allocateIndexID, err := s.allocator.AllocID(ctx)
indexID, err := s.meta.indexMeta.CanCreateIndex(req, isJson)
if err != nil {
log.Warn("failed to alloc indexID", zap.Error(err))
log.Error("Check CanCreateIndex fail", zap.Error(err))
metrics.IndexRequestCounter.WithLabelValues(metrics.FailLabel).Inc()
return merr.Status(err), nil
}
if indexID == 0 {
if indexID, err = s.allocator.AllocID(ctx); err != nil {
log.Warn("failed to alloc indexID", zap.Error(err))
metrics.IndexRequestCounter.WithLabelValues(metrics.FailLabel).Inc()
return merr.Status(err), nil
}
}
// Get flushed segments and create index
indexID, err := s.meta.indexMeta.CreateIndex(ctx, req, allocateIndexID, isJson)
if err != nil {
log.Error("CreateIndex fail",
zap.Int64("fieldID", req.GetFieldID()), zap.String("indexName", req.GetIndexName()), zap.Error(err))
// exclude the mmap.enable param, because it will be conflicted with the index's mmap.enable param
typeParams := DeleteParams(req.GetTypeParams(), []string{common.MmapEnabledKey})
index := &model.Index{
CollectionID: req.GetCollectionID(),
FieldID: req.GetFieldID(),
IndexID: indexID,
IndexName: req.GetIndexName(),
TypeParams: typeParams,
IndexParams: req.GetIndexParams(),
CreateTime: req.GetTimestamp(),
IsAutoIndex: req.GetIsAutoIndex(),
UserIndexParams: req.GetUserIndexParams(),
}
// Validate the index params.
if err := ValidateIndexParams(index); err != nil {
return nil, err
}
if _, err = broadcaster.Broadcast(ctx, message.NewCreateIndexMessageBuilderV2().
WithHeader(&message.CreateIndexMessageHeader{
DbId: coll.GetDbId(),
CollectionId: req.GetCollectionID(),
FieldId: req.GetFieldID(),
IndexId: indexID,
IndexName: req.GetIndexName(),
}).
WithBody(&message.CreateIndexMessageBody{
FieldIndex: model.MarshalIndexModel(index),
}).
WithBroadcast([]string{streaming.WAL().ControlChannel()}).
MustBuildBroadcast(),
); err != nil {
log.Error("CreateIndex fail", zap.Error(err))
metrics.IndexRequestCounter.WithLabelValues(metrics.FailLabel).Inc()
return merr.Status(err), nil
}
select {
case s.notifyIndexChan <- req.GetCollectionID():
default:
}
log.Info("CreateIndex successfully",
zap.String("IndexName", req.GetIndexName()), zap.Int64("fieldID", req.GetFieldID()),
zap.Int64("IndexID", indexID))
zap.String("IndexName", index.IndexName), zap.Int64("fieldID", index.FieldID),
zap.Int64("IndexID", index.IndexID))
metrics.IndexRequestCounter.WithLabelValues(metrics.SuccessLabel).Inc()
return merr.Success(), nil
}
@ -322,6 +361,12 @@ func (s *Server) AlterIndex(ctx context.Context, req *indexpb.AlterIndexRequest)
return merr.Status(err), nil
}
broadcaster, err := s.startBroadcastWithCollectionID(ctx, req.GetCollectionID())
if err != nil {
return merr.Status(err), nil
}
defer broadcaster.Close()
indexes := s.meta.indexMeta.GetIndexesForCollection(req.GetCollectionID(), req.GetIndexName())
if len(indexes) == 0 {
err := merr.WrapErrIndexNotFound(req.GetIndexName())
@ -397,12 +442,27 @@ func (s *Server) AlterIndex(ctx context.Context, req *indexpb.AlterIndexRequest)
}
}
err = s.meta.indexMeta.AlterIndex(ctx, indexes...)
if err != nil {
log.Warn("failed to alter index", zap.Error(err))
indexIDs := lo.Map(indexes, func(index *model.Index, _ int) int64 {
return index.IndexID
})
msg := message.NewAlterIndexMessageBuilderV2().
WithHeader(&message.AlterIndexMessageHeader{
CollectionId: req.GetCollectionID(),
IndexIds: indexIDs,
}).
WithBody(&message.AlterIndexMessageBody{
FieldIndexes: lo.Map(indexes, func(index *model.Index, _ int) *indexpb.FieldIndex {
return model.MarshalIndexModel(index)
}),
}).
WithBroadcast([]string{streaming.WAL().ControlChannel()}).
MustBuildBroadcast()
if _, err := broadcaster.Broadcast(ctx, msg); err != nil {
log.Warn("failed to broadcast alter index message", zap.Error(err))
return merr.Status(err), nil
}
log.Info("broadcast alter index message successfully", zap.Int64("collectionID", req.GetCollectionID()), zap.Int64s("indexIDs", indexIDs))
return merr.Success(), nil
}
@ -879,6 +939,21 @@ func (s *Server) DropIndex(ctx context.Context, req *indexpb.DropIndexRequest) (
return merr.Status(err), nil
}
// Compatibility logic. To prevent the index on the corresponding segments
// from being dropped at the same time when dropping_partition in version 2.1
if len(req.PartitionIDs) > 0 {
log.Warn("drop index on partition is deprecated, please use drop index on collection instead",
zap.Int64s("partitionIDs", req.GetPartitionIDs()))
return merr.Success(), nil
}
// Create a new broadcaster for the collection.
broadcaster, err := s.startBroadcastWithCollectionID(ctx, req.GetCollectionID())
if err != nil {
return merr.Status(err), nil
}
defer broadcaster.Close()
indexes := s.meta.indexMeta.GetIndexesForCollection(req.GetCollectionID(), req.GetIndexName())
if len(indexes) == 0 {
log.Info(fmt.Sprintf("there is no index on collection: %d with the index name: %s", req.CollectionID, req.IndexName))
@ -918,18 +993,21 @@ func (s *Server) DropIndex(ctx context.Context, req *indexpb.DropIndexRequest) (
for _, index := range indexes {
indexIDs = append(indexIDs, index.IndexID)
}
// Compatibility logic. To prevent the index on the corresponding segments
// from being dropped at the same time when dropping_partition in version 2.1
if len(req.GetPartitionIDs()) == 0 {
// drop collection index
err := s.meta.indexMeta.MarkIndexAsDeleted(ctx, req.GetCollectionID(), indexIDs)
if err != nil {
log.Warn("DropIndex fail", zap.String("indexName", req.IndexName), zap.Error(err))
return merr.Status(err), nil
}
}
log.Debug("DropIndex success", zap.Int64s("partitionIDs", req.GetPartitionIDs()),
msg := message.NewDropIndexMessageBuilderV2().
WithHeader(&message.DropIndexMessageHeader{
CollectionId: req.GetCollectionID(),
IndexIds: indexIDs,
}).
WithBody(&message.DropIndexMessageBody{}).
WithBroadcast([]string{streaming.WAL().ControlChannel()}).
MustBuildBroadcast()
if _, err := broadcaster.Broadcast(ctx, msg); err != nil {
log.Warn("failed to broadcast drop index message", zap.Error(err))
return merr.Status(err), nil
}
log.Info("DropIndex success", zap.Int64s("partitionIDs", req.GetPartitionIDs()),
zap.String("indexName", req.GetIndexName()), zap.Int64s("indexIDs", indexIDs))
return merr.Success(), nil
}

View File

@ -18,7 +18,9 @@ package datacoord
import (
"context"
"fmt"
"math"
"math/rand"
"strconv"
"testing"
"time"
@ -31,31 +33,99 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/coordinator/snmanager"
"github.com/milvus-io/milvus/internal/datacoord/allocator"
"github.com/milvus-io/milvus/internal/datacoord/broker"
"github.com/milvus-io/milvus/internal/distributed/streaming"
mockkv "github.com/milvus-io/milvus/internal/kv/mocks"
"github.com/milvus-io/milvus/internal/metastore/kv/datacoord"
catalogmocks "github.com/milvus-io/milvus/internal/metastore/mocks"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/mocks/distributed/mock_streaming"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_broadcaster"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/channel"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/broadcast"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/retry"
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func initStreamingSystem(t *testing.T) {
registry.ResetRegistration()
wal := mock_streaming.NewMockWALAccesser(t)
wal.EXPECT().ControlChannel().Return(funcutil.GetControlChannel("by-dev-rootcoord-dml_0")).Maybe()
streaming.SetWALForTest(wal)
bapi := mock_broadcaster.NewMockBroadcastAPI(t)
bapi.EXPECT().Broadcast(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, msg message.BroadcastMutableMessage) (*types.BroadcastAppendResult, error) {
results := make(map[string]*message.AppendResult)
for _, vchannel := range msg.BroadcastHeader().VChannels {
results[vchannel] = &message.AppendResult{
MessageID: rmq.NewRmqID(1),
TimeTick: tsoutil.ComposeTSByTime(time.Now(), 0),
LastConfirmedMessageID: rmq.NewRmqID(1),
}
}
retry.Do(context.Background(), func() error {
return registry.CallMessageAckCallback(context.Background(), msg, results)
}, retry.AttemptAlways())
return &types.BroadcastAppendResult{}, nil
}).Maybe()
bapi.EXPECT().Close().Return().Maybe()
mb := mock_broadcaster.NewMockBroadcaster(t)
mb.EXPECT().WithResourceKeys(mock.Anything, mock.Anything).Return(bapi, nil).Maybe()
mb.EXPECT().WithResourceKeys(mock.Anything, mock.Anything, mock.Anything).Return(bapi, nil).Maybe()
mb.EXPECT().WithResourceKeys(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(bapi, nil).Maybe()
mb.EXPECT().Close().Return().Maybe()
broadcast.Release()
broadcast.ResetBroadcaster()
broadcast.Register(mb)
snmanager.ResetStreamingNodeManager()
b := mock_balancer.NewMockBalancer(t)
b.EXPECT().AllocVirtualChannels(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, param balancer.AllocVChannelParam) ([]string, error) {
vchannels := make([]string, 0, param.Num)
for i := 0; i < param.Num; i++ {
vchannels = append(vchannels, funcutil.GetVirtualChannel(fmt.Sprintf("by-dev-rootcoord-dml_%d_100v0", i), param.CollectionID, i))
}
return vchannels, nil
}).Maybe()
b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, callback balancer.WatchChannelAssignmentsCallback) error {
<-ctx.Done()
return ctx.Err()
}).Maybe()
b.EXPECT().Close().Return().Maybe()
balance.Register(b)
channel.ResetStaticPChannelStatsManager()
channel.RecoverPChannelStatsManager([]string{})
}
func TestServerId(t *testing.T) {
s := &Server{session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: 0}}}
assert.Equal(t, int64(0), s.serverID())
}
func TestServer_CreateIndex(t *testing.T) {
initStreamingSystem(t)
var (
collID = UniqueID(1)
fieldID = UniqueID(10)
@ -109,6 +179,7 @@ func TestServer_CreateIndex(t *testing.T) {
allocator: mock0Allocator,
notifyIndexChan: make(chan UniqueID, 1),
}
RegisterDDLCallbacks(s)
s.stateCode.Store(commonpb.StateCode_Healthy)
@ -292,8 +363,18 @@ func TestServer_CreateIndex(t *testing.T) {
t.Run("save index fail", func(t *testing.T) {
metakv := mockkv.NewMetaKv(t)
metakv.EXPECT().Save(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("failed")).Maybe()
metakv.EXPECT().MultiSave(mock.Anything, mock.Anything).Return(errors.New("failed")).Maybe()
metakv.EXPECT().Save(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, key string, value string) error {
if rand.Intn(3) == 0 {
return errors.New("failed")
}
return nil
}).Maybe()
metakv.EXPECT().MultiSave(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, kvs map[string]string) error {
if rand.Intn(3) == 0 {
return errors.New("failed")
}
return nil
}).Maybe()
s.meta.indexMeta.indexes = map[UniqueID]map[UniqueID]*model.Index{}
s.meta.catalog = &datacoord.Catalog{MetaKv: metakv}
s.meta.indexMeta.catalog = s.meta.catalog
@ -304,11 +385,12 @@ func TestServer_CreateIndex(t *testing.T) {
},
}
resp, err := s.CreateIndex(ctx, req)
assert.Error(t, merr.CheckRPCCall(resp, err))
assert.NoError(t, merr.CheckRPCCall(resp, err))
})
}
func TestServer_AlterIndex(t *testing.T) {
initStreamingSystem(t)
var (
collID = UniqueID(1)
partID = UniqueID(2)
@ -610,6 +692,12 @@ func TestServer_AlterIndex(t *testing.T) {
},
}, nil).Once()
}
b := broker.NewMockBroker(t)
b.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
Status: merr.Status(nil),
DbName: "test_db",
CollectionName: "test_collection",
}, nil)
s := &Server{
meta: &meta{
@ -669,7 +757,9 @@ func TestServer_AlterIndex(t *testing.T) {
allocator: mock0Allocator,
notifyIndexChan: make(chan UniqueID, 1),
handler: mockHandler,
broker: b,
}
RegisterDDLCallbacks(s)
t.Run("server not available", func(t *testing.T) {
s.stateCode.Store(commonpb.StateCode_Initializing)
@ -1267,6 +1357,8 @@ func TestServer_GetIndexBuildProgress(t *testing.T) {
}
func TestServer_DescribeIndex(t *testing.T) {
initStreamingSystem(t)
var (
collID = UniqueID(1)
partID = UniqueID(2)
@ -1352,6 +1444,13 @@ func TestServer_DescribeIndex(t *testing.T) {
},
},
}
b := broker.NewMockBroker(t)
b.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
Status: merr.Status(nil),
DbName: "test_db",
CollectionName: "test_collection",
}, nil)
s := &Server{
meta: &meta{
catalog: catalog,
@ -1453,6 +1552,7 @@ func TestServer_DescribeIndex(t *testing.T) {
mixCoord: mocks.NewMixCoord(t),
allocator: mock0Allocator,
notifyIndexChan: make(chan UniqueID, 1),
broker: b,
}
segIdx1 := typeutil.NewConcurrentMap[UniqueID, *model.SegmentIndex]()
segIdx1.Insert(indexID, &model.SegmentIndex{
@ -1615,6 +1715,7 @@ func TestServer_DescribeIndex(t *testing.T) {
for id, segment := range segments {
s.meta.segments.SetSegment(id, segment)
}
RegisterDDLCallbacks(s)
t.Run("server not available", func(t *testing.T) {
s.stateCode.Store(commonpb.StateCode_Initializing)
@ -1818,6 +1919,7 @@ func TestServer_ListIndexes(t *testing.T) {
}
func TestServer_GetIndexStatistics(t *testing.T) {
initStreamingSystem(t)
var (
collID = UniqueID(1)
partID = UniqueID(2)
@ -1886,6 +1988,12 @@ func TestServer_GetIndexStatistics(t *testing.T) {
},
},
}
b := broker.NewMockBroker(t)
b.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
Status: merr.Status(nil),
DbName: "test_db",
CollectionName: "test_collection",
}, nil)
s := &Server{
meta: &meta{
catalog: catalog,
@ -1987,6 +2095,7 @@ func TestServer_GetIndexStatistics(t *testing.T) {
mixCoord: mocks.NewMixCoord(t),
allocator: mock0Allocator,
notifyIndexChan: make(chan UniqueID, 1),
broker: b,
}
segIdx1 := typeutil.NewConcurrentMap[UniqueID, *model.SegmentIndex]()
segIdx1.Insert(indexID, &model.SegmentIndex{
@ -2078,6 +2187,7 @@ func TestServer_GetIndexStatistics(t *testing.T) {
for id, segment := range segments {
s.meta.segments.SetSegment(id, segment)
}
RegisterDDLCallbacks(s)
t.Run("server not available", func(t *testing.T) {
s.stateCode.Store(commonpb.StateCode_Initializing)
@ -2116,6 +2226,7 @@ func TestServer_GetIndexStatistics(t *testing.T) {
}
func TestServer_DropIndex(t *testing.T) {
initStreamingSystem(t)
var (
collID = UniqueID(1)
partID = UniqueID(2)
@ -2151,6 +2262,13 @@ func TestServer_DropIndex(t *testing.T) {
mock0Allocator := newMockAllocator(t)
b := broker.NewMockBroker(t)
b.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
Status: merr.Status(nil),
DbName: "test_db",
CollectionName: "test_collection",
}, nil)
s := &Server{
meta: &meta{
catalog: catalog,
@ -2235,6 +2353,7 @@ func TestServer_DropIndex(t *testing.T) {
segments: NewSegmentsInfo(),
},
broker: b,
allocator: mock0Allocator,
notifyIndexChan: make(chan UniqueID, 1),
}
@ -2265,18 +2384,20 @@ func TestServer_DropIndex(t *testing.T) {
assert.ErrorIs(t, merr.Error(resp), merr.ErrServiceNotReady)
})
RegisterDDLCallbacks(s)
s.stateCode.Store(commonpb.StateCode_Healthy)
t.Run("drop fail", func(t *testing.T) {
catalog := catalogmocks.NewDataCoordCatalog(t)
catalog.On("AlterIndexes",
mock.Anything,
mock.Anything,
).Return(errors.New("fail"))
catalog.EXPECT().AlterIndexes(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, indexes []*model.Index) error {
if rand.Intn(3) == 0 {
return errors.New("fail")
}
return nil
}).Maybe()
s.meta.indexMeta.catalog = catalog
resp, err := s.DropIndex(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetErrorCode())
assert.NoError(t, merr.CheckRPCCall(resp, err))
})
t.Run("drop one index", func(t *testing.T) {
@ -2672,6 +2793,8 @@ func TestValidateIndexParams(t *testing.T) {
}
func TestJsonIndex(t *testing.T) {
initStreamingSystem(t)
collID := UniqueID(1)
catalog := catalogmocks.NewDataCoordCatalog(t)
catalog.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(nil).Maybe()
@ -2721,6 +2844,7 @@ func TestJsonIndex(t *testing.T) {
notifyIndexChan: make(chan UniqueID, 1),
broker: broker.NewCoordinatorBroker(b),
}
RegisterDDLCallbacks(s)
s.stateCode.Store(commonpb.StateCode_Healthy)
req := &indexpb.CreateIndexRequest{

View File

@ -335,6 +335,7 @@ func (s *Server) initDataCoord() error {
s.serverLoopCtx, s.serverLoopCancel = context.WithCancel(s.ctx)
RegisterDDLCallbacks(s)
log.Info("init datacoord done", zap.Int64("nodeID", paramtable.GetNodeID()), zap.String("Address", s.address))
s.initMessageCallback()
@ -344,16 +345,6 @@ func (s *Server) initDataCoord() error {
// initMessageCallback initializes the message callback.
// TODO: we should build a ddl framework to handle the message ack callback for ddl messages
func (s *Server) initMessageCallback() {
registry.RegisterDropPartitionV1AckCallback(func(ctx context.Context, result message.BroadcastResultDropPartitionMessageV1) error {
partitionID := result.Message.Header().PartitionId
for _, vchannel := range result.GetVChannelsWithoutControlChannel() {
if err := s.NotifyDropPartition(ctx, vchannel, []int64{partitionID}); err != nil {
return err
}
}
return nil
})
registry.RegisterImportV1AckCallback(func(ctx context.Context, result message.BroadcastResultImportMessageV1) error {
body := result.Message.MustBody()
vchannels := result.GetVChannelsWithoutControlChannel()

View File

@ -60,7 +60,6 @@ import (
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/workerpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
@ -78,14 +77,8 @@ const maxOperationsPerTxn = int64(64)
func TestMain(m *testing.M) {
paramtable.Init()
rand.Seed(time.Now().UnixNano())
parameters := []string{"tikv", "etcd"}
var code int
for _, v := range parameters {
paramtable.Get().Save(paramtable.Get().MetaStoreCfg.MetaStoreType.Key, v)
code = m.Run()
}
code := m.Run()
os.Exit(code)
}
@ -1194,12 +1187,11 @@ func TestGetRecoveryInfo(t *testing.T) {
})
assert.NoError(t, err)
indexReq := &indexpb.CreateIndexRequest{
err = svr.meta.indexMeta.CreateIndex(context.TODO(), &model.Index{
CollectionID: 0,
FieldID: 2,
}
_, err = svr.meta.indexMeta.CreateIndex(context.TODO(), indexReq, 0, false)
IndexID: rand.Int63n(1000),
})
assert.NoError(t, err)
seg1 := createSegment(0, 0, 0, 100, 10, "vchan1", commonpb.SegmentState_Flushed)
@ -1413,12 +1405,11 @@ func TestGetRecoveryInfo(t *testing.T) {
err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segment))
assert.NoError(t, err)
indexReq := &indexpb.CreateIndexRequest{
err = svr.meta.indexMeta.CreateIndex(context.TODO(), &model.Index{
CollectionID: 0,
FieldID: 2,
}
_, err = svr.meta.indexMeta.CreateIndex(context.TODO(), indexReq, 0, false)
IndexID: rand.Int63n(1000),
})
assert.NoError(t, err)
err = svr.meta.indexMeta.AddSegmentIndex(context.TODO(), &model.SegmentIndex{
SegmentID: segment.ID,
@ -1573,12 +1564,12 @@ func TestGetRecoveryInfo(t *testing.T) {
assert.NoError(t, err)
err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg5))
assert.NoError(t, err)
indexReq := &indexpb.CreateIndexRequest{
err = svr.meta.indexMeta.CreateIndex(context.TODO(), &model.Index{
CollectionID: 0,
FieldID: 2,
IndexID: rand.Int63n(1000),
IndexName: "_default_idx_2",
}
_, err = svr.meta.indexMeta.CreateIndex(context.TODO(), indexReq, 0, false)
})
assert.NoError(t, err)
svr.meta.indexMeta.updateSegmentIndex(&model.SegmentIndex{
SegmentID: seg4.ID,
@ -2754,7 +2745,11 @@ func TestServer_InitMessageCallback(t *testing.T) {
mb := mock_balancer.NewMockBalancer(t)
mb.EXPECT().GetLatestChannelAssignment().Return(&balancer.WatchChannelAssignmentsCallbackParam{}, nil).Maybe()
balance.ResetBalancer()
mb.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, cb balancer.WatchChannelAssignmentsCallback) error {
<-ctx.Done()
return ctx.Err()
}).Maybe()
snmanager.ResetStreamingNodeManager()
balance.Register(mb)
server := &Server{
@ -2772,28 +2767,6 @@ func TestServer_InitMessageCallback(t *testing.T) {
// Test initMessageCallback
server.initMessageCallback()
// Test DropPartition message callback
dropPartitionMsg := message.NewDropPartitionMessageBuilderV1().
WithHeader(&message.DropPartitionMessageHeader{
CollectionId: 1,
PartitionId: 1,
}).
WithBody(&msgpb.DropPartitionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_DropPartition,
},
}).
WithBroadcast([]string{"test_channel"}, message.NewImportJobIDResourceKey(1)).
MustBuildBroadcast()
err := registry.CallMessageAckCallback(ctx, dropPartitionMsg, map[string]*message.AppendResult{
"test_channel": {
MessageID: walimplstest.NewTestMessageID(1),
LastConfirmedMessageID: walimplstest.NewTestMessageID(1),
TimeTick: 1,
},
})
assert.Error(t, err) // server not healthy
// Test Import message check callback
resourceKey := message.NewImportJobIDResourceKey(1)
msg, err := message.NewImportMessageBuilderV1().

View File

@ -1395,6 +1395,7 @@ func (s *Server) GetCompactionStateWithPlans(ctx context.Context, req *milvuspb.
}
// WatchChannels notifies DataCoord to watch vchannels of a collection.
// Deprecated: Redundant design by now, remove it in future.
func (s *Server) WatchChannels(ctx context.Context, req *datapb.WatchChannelsRequest) (*datapb.WatchChannelsResponse, error) {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
@ -1412,6 +1413,7 @@ func (s *Server) WatchChannels(ctx context.Context, req *datapb.WatchChannelsReq
}, nil
}
for _, channelName := range req.GetChannelNames() {
// TODO: redundant channel mark by now, remove it in future.
if err := s.meta.catalog.MarkChannelAdded(ctx, channelName); err != nil {
// TODO: add background task to periodically cleanup the orphaned channel add marks.
log.Error("failed to mark channel added", zap.Error(err))

View File

@ -3,6 +3,7 @@ package datacoord
import (
"context"
"fmt"
"math/rand"
"testing"
"time"
@ -36,9 +37,7 @@ import (
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/kv"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/proto/workerpb"
@ -83,24 +82,6 @@ func TestServerSuite(t *testing.T) {
suite.Run(t, new(ServerSuite))
}
func genMsg(msgType commonpb.MsgType, ch string, t Timestamp, sourceID int64) *msgstream.DataNodeTtMsg {
return &msgstream.DataNodeTtMsg{
BaseMsg: msgstream.BaseMsg{
HashValues: []uint32{0},
},
DataNodeTtMsg: &msgpb.DataNodeTtMsg{
Base: &commonpb.MsgBase{
MsgType: msgType,
Timestamp: t,
SourceID: sourceID,
},
ChannelName: ch,
Timestamp: t,
SegmentsStats: []*commonpb.SegmentStats{{SegmentID: 2, NumRows: 100}},
},
}
}
func (s *ServerSuite) TestGetFlushState_ByFlushTs() {
s.mockMixCoord.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, req *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) {
if req.CollectionID == 0 {
@ -1023,11 +1004,11 @@ func TestGetRecoveryInfoV2(t *testing.T) {
})
assert.NoError(t, err)
indexReq := &indexpb.CreateIndexRequest{
err = svr.meta.indexMeta.CreateIndex(context.TODO(), &model.Index{
CollectionID: 0,
FieldID: 2,
}
_, err = svr.meta.indexMeta.CreateIndex(context.TODO(), indexReq, 0, false)
IndexID: rand.Int63n(1000),
})
assert.NoError(t, err)
seg1 := createSegment(0, 0, 0, 100, 10, "vchan1", commonpb.SegmentState_Flushed)
@ -1235,12 +1216,11 @@ func TestGetRecoveryInfoV2(t *testing.T) {
err := svr.meta.AddSegment(context.TODO(), NewSegmentInfo(segment))
assert.NoError(t, err)
indexReq := &indexpb.CreateIndexRequest{
err = svr.meta.indexMeta.CreateIndex(context.TODO(), &model.Index{
CollectionID: 0,
FieldID: 2,
}
_, err = svr.meta.indexMeta.CreateIndex(context.TODO(), indexReq, 0, false)
IndexID: rand.Int63n(1000),
})
assert.NoError(t, err)
err = svr.meta.indexMeta.AddSegmentIndex(context.TODO(), &model.SegmentIndex{
SegmentID: segment.ID,
@ -1389,12 +1369,12 @@ func TestGetRecoveryInfoV2(t *testing.T) {
assert.NoError(t, err)
err = svr.meta.AddSegment(context.TODO(), NewSegmentInfo(seg5))
assert.NoError(t, err)
indexReq := &indexpb.CreateIndexRequest{
err = svr.meta.indexMeta.CreateIndex(context.TODO(), &model.Index{
CollectionID: 0,
FieldID: 2,
IndexID: rand.Int63n(1000),
IndexName: "_default_idx_2",
}
_, err = svr.meta.indexMeta.CreateIndex(context.TODO(), indexReq, 0, false)
})
assert.NoError(t, err)
svr.meta.indexMeta.updateSegmentIndex(&model.SegmentIndex{
SegmentID: seg4.ID,

View File

@ -147,7 +147,7 @@ func TestStreamingBroadcast(t *testing.T) {
CollectionID: 1,
CollectionName: collectionName,
}).
WithBroadcast(vChannels, message.NewCollectionNameResourceKey(collectionName)).
WithBroadcast(vChannels, message.NewExclusiveCollectionNameResourceKey("db", collectionName)).
BuildBroadcast()
resp, err := streaming.WAL().Broadcast().Append(context.Background(), msg)

View File

@ -126,6 +126,7 @@ type DataCoordCatalog interface {
SaveDroppedSegmentsInBatch(ctx context.Context, segments []*datapb.SegmentInfo) error
DropSegment(ctx context.Context, segment *datapb.SegmentInfo) error
// TODO: From MarkChannelAdded to DropChannel, it's totally a redundant design by now, remove it in future.
MarkChannelAdded(ctx context.Context, channel string) error
MarkChannelDeleted(ctx context.Context, channel string) error
ShouldDropChannel(ctx context.Context, channel string) bool

View File

@ -171,8 +171,8 @@ func (kc *Catalog) ListDatabases(ctx context.Context, ts typeutil.Timestamp) ([]
}
func (kc *Catalog) CreateCollection(ctx context.Context, coll *model.Collection, ts typeutil.Timestamp) error {
if coll.State != pb.CollectionState_CollectionCreating {
return fmt.Errorf("cannot create collection with state: %s, collection: %s", coll.State.String(), coll.Name)
if coll.State != pb.CollectionState_CollectionCreated {
return fmt.Errorf("collection state should be created, collection name: %s, collection id: %d, state: %s", coll.Name, coll.CollectionID, coll.State)
}
k1 := BuildCollectionKey(coll.DBID, coll.CollectionID)

View File

@ -1058,7 +1058,7 @@ func TestCatalog_AlterCollection(t *testing.T) {
kc := NewCatalog(nil, snapshot).(*Catalog)
ctx := context.Background()
var collectionID int64 = 1
oldC := &model.Collection{CollectionID: collectionID, State: pb.CollectionState_CollectionCreating}
oldC := &model.Collection{CollectionID: collectionID, State: pb.CollectionState_CollectionCreated}
newC := &model.Collection{CollectionID: collectionID, State: pb.CollectionState_CollectionCreated, UpdateTimestamp: rand.Uint64()}
err := kc.AlterCollection(ctx, oldC, newC, metastore.MODIFY, 0, true)
assert.NoError(t, err)
@ -1077,7 +1077,7 @@ func TestCatalog_AlterCollection(t *testing.T) {
kc := NewCatalog(nil, nil)
ctx := context.Background()
var collectionID int64 = 1
oldC := &model.Collection{TenantID: "1", CollectionID: collectionID, State: pb.CollectionState_CollectionCreating}
oldC := &model.Collection{TenantID: "1", CollectionID: collectionID, State: pb.CollectionState_CollectionCreated}
newC := &model.Collection{TenantID: "2", CollectionID: collectionID, State: pb.CollectionState_CollectionCreated}
err := kc.AlterCollection(ctx, oldC, newC, metastore.MODIFY, 0, true)
assert.Error(t, err)
@ -1267,7 +1267,7 @@ func TestCatalog_CreateCollection(t *testing.T) {
mockSnapshot := newMockSnapshot(t, withMockSave(errors.New("error mock Save")))
kc := NewCatalog(nil, mockSnapshot)
ctx := context.Background()
coll := &model.Collection{State: pb.CollectionState_CollectionCreating}
coll := &model.Collection{State: pb.CollectionState_CollectionCreated}
err := kc.CreateCollection(ctx, coll, 100)
assert.Error(t, err)
})
@ -1280,7 +1280,7 @@ func TestCatalog_CreateCollection(t *testing.T) {
Partitions: []*model.Partition{
{PartitionName: "test"},
},
State: pb.CollectionState_CollectionCreating,
State: pb.CollectionState_CollectionCreated,
}
err := kc.CreateCollection(ctx, coll, 100)
assert.Error(t, err)
@ -1294,7 +1294,7 @@ func TestCatalog_CreateCollection(t *testing.T) {
Partitions: []*model.Partition{
{PartitionName: "test"},
},
State: pb.CollectionState_CollectionCreating,
State: pb.CollectionState_CollectionCreated,
}
err := kc.CreateCollection(ctx, coll, 100)
assert.NoError(t, err)
@ -1349,7 +1349,7 @@ func TestCatalog_CreateCollection(t *testing.T) {
OutputFieldNames: []string{"sparse"},
},
},
State: pb.CollectionState_CollectionCreating,
State: pb.CollectionState_CollectionCreated,
}
err := kc.CreateCollection(ctx, coll, 100)
assert.NoError(t, err)

View File

@ -1,7 +1,6 @@
package model
import (
"github.com/milvus-io/milvus/pkg/v2/common"
pb "github.com/milvus-io/milvus/pkg/v2/proto/etcdpb"
)
@ -9,7 +8,6 @@ type Partition struct {
PartitionID int64
PartitionName string
PartitionCreatedTimestamp uint64
Extra map[string]string // deprecated.
CollectionID int64
State pb.PartitionState
}
@ -23,7 +21,6 @@ func (p *Partition) Clone() *Partition {
PartitionID: p.PartitionID,
PartitionName: p.PartitionName,
PartitionCreatedTimestamp: p.PartitionCreatedTimestamp,
Extra: common.CloneStr2Str(p.Extra),
CollectionID: p.CollectionID,
State: p.State,
}

View File

@ -0,0 +1,100 @@
// Code generated by mockery v2.53.3. DO NOT EDIT.
package mock_tombstone
import (
tombstone "github.com/milvus-io/milvus/internal/rootcoord/tombstone"
mock "github.com/stretchr/testify/mock"
)
// MockTombstoneSweeper is an autogenerated mock type for the TombstoneSweeper type
type MockTombstoneSweeper struct {
mock.Mock
}
type MockTombstoneSweeper_Expecter struct {
mock *mock.Mock
}
func (_m *MockTombstoneSweeper) EXPECT() *MockTombstoneSweeper_Expecter {
return &MockTombstoneSweeper_Expecter{mock: &_m.Mock}
}
// AddTombstone provides a mock function with given fields: _a0
func (_m *MockTombstoneSweeper) AddTombstone(_a0 tombstone.Tombstone) {
_m.Called(_a0)
}
// MockTombstoneSweeper_AddTombstone_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddTombstone'
type MockTombstoneSweeper_AddTombstone_Call struct {
*mock.Call
}
// AddTombstone is a helper method to define mock.On call
// - _a0 tombstone.Tombstone
func (_e *MockTombstoneSweeper_Expecter) AddTombstone(_a0 interface{}) *MockTombstoneSweeper_AddTombstone_Call {
return &MockTombstoneSweeper_AddTombstone_Call{Call: _e.mock.On("AddTombstone", _a0)}
}
func (_c *MockTombstoneSweeper_AddTombstone_Call) Run(run func(_a0 tombstone.Tombstone)) *MockTombstoneSweeper_AddTombstone_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(tombstone.Tombstone))
})
return _c
}
func (_c *MockTombstoneSweeper_AddTombstone_Call) Return() *MockTombstoneSweeper_AddTombstone_Call {
_c.Call.Return()
return _c
}
func (_c *MockTombstoneSweeper_AddTombstone_Call) RunAndReturn(run func(tombstone.Tombstone)) *MockTombstoneSweeper_AddTombstone_Call {
_c.Run(run)
return _c
}
// Close provides a mock function with no fields
func (_m *MockTombstoneSweeper) Close() {
_m.Called()
}
// MockTombstoneSweeper_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
type MockTombstoneSweeper_Close_Call struct {
*mock.Call
}
// Close is a helper method to define mock.On call
func (_e *MockTombstoneSweeper_Expecter) Close() *MockTombstoneSweeper_Close_Call {
return &MockTombstoneSweeper_Close_Call{Call: _e.mock.On("Close")}
}
func (_c *MockTombstoneSweeper_Close_Call) Run(run func()) *MockTombstoneSweeper_Close_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockTombstoneSweeper_Close_Call) Return() *MockTombstoneSweeper_Close_Call {
_c.Call.Return()
return _c
}
func (_c *MockTombstoneSweeper_Close_Call) RunAndReturn(run func()) *MockTombstoneSweeper_Close_Call {
_c.Run(run)
return _c
}
// NewMockTombstoneSweeper creates a new instance of MockTombstoneSweeper. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockTombstoneSweeper(t interface {
mock.TestingT
Cleanup(func())
}) *MockTombstoneSweeper {
mock := &MockTombstoneSweeper{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
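
For context, a generated expecter mock like the one above is driven by testify's mock package underneath. A minimal, self-contained sketch of the same pattern, using a hypothetical Sweeper interface and hand-rolled mock rather than the generated Milvus type:

package sweeper_test

import (
    "testing"

    "github.com/stretchr/testify/mock"
)

// Sweeper is a hypothetical stand-in for the TombstoneSweeper interface.
type Sweeper interface {
    AddTombstone(id int64)
    Close()
}

// MockSweeper embeds mock.Mock the same way mockery-generated mocks do.
type MockSweeper struct{ mock.Mock }

func (m *MockSweeper) AddTombstone(id int64) { m.Called(id) }
func (m *MockSweeper) Close()                { m.Called() }

func TestSweeperExpectations(t *testing.T) {
    m := &MockSweeper{}
    m.Test(t)
    defer m.AssertExpectations(t)

    // Expect exactly one AddTombstone call with any id and one Close call.
    m.On("AddTombstone", mock.Anything).Once()
    m.On("Close").Once()

    var s Sweeper = m
    s.AddTombstone(42)
    s.Close()
}

The generated Expecter wrapper above only adds type-safe sugar over these same On/Run/Return calls.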

View File

@ -33,6 +33,65 @@ func (_m *MockBalancer) EXPECT() *MockBalancer_Expecter {
return &MockBalancer_Expecter{mock: &_m.Mock}
}
// AllocVirtualChannels provides a mock function with given fields: ctx, param
func (_m *MockBalancer) AllocVirtualChannels(ctx context.Context, param balancer.AllocVChannelParam) ([]string, error) {
ret := _m.Called(ctx, param)
if len(ret) == 0 {
panic("no return value specified for AllocVirtualChannels")
}
var r0 []string
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, balancer.AllocVChannelParam) ([]string, error)); ok {
return rf(ctx, param)
}
if rf, ok := ret.Get(0).(func(context.Context, balancer.AllocVChannelParam) []string); ok {
r0 = rf(ctx, param)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]string)
}
}
if rf, ok := ret.Get(1).(func(context.Context, balancer.AllocVChannelParam) error); ok {
r1 = rf(ctx, param)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockBalancer_AllocVirtualChannels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AllocVirtualChannels'
type MockBalancer_AllocVirtualChannels_Call struct {
*mock.Call
}
// AllocVirtualChannels is a helper method to define mock.On call
// - ctx context.Context
// - param balancer.AllocVChannelParam
func (_e *MockBalancer_Expecter) AllocVirtualChannels(ctx interface{}, param interface{}) *MockBalancer_AllocVirtualChannels_Call {
return &MockBalancer_AllocVirtualChannels_Call{Call: _e.mock.On("AllocVirtualChannels", ctx, param)}
}
func (_c *MockBalancer_AllocVirtualChannels_Call) Run(run func(ctx context.Context, param balancer.AllocVChannelParam)) *MockBalancer_AllocVirtualChannels_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(balancer.AllocVChannelParam))
})
return _c
}
func (_c *MockBalancer_AllocVirtualChannels_Call) Return(_a0 []string, _a1 error) *MockBalancer_AllocVirtualChannels_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockBalancer_AllocVirtualChannels_Call) RunAndReturn(run func(context.Context, balancer.AllocVChannelParam) ([]string, error)) *MockBalancer_AllocVirtualChannels_Call {
_c.Call.Return(run)
return _c
}
// Close provides a mock function with no fields
func (_m *MockBalancer) Close() {
_m.Called()

View File

@ -2981,8 +2981,8 @@ type hybridSearchRequestExprLogger struct {
*milvuspb.HybridSearchRequest
}
// String implements Stringer interface for lazy logging.
func (l *hybridSearchRequestExprLogger) String() string {
// Key builds the request expression string for lazy logging.
func (l *hybridSearchRequestExprLogger) Key() string {
builder := &strings.Builder{}
for idx, subReq := range l.Requests {

View File

@ -16,13 +16,23 @@
package querycoordv2
import "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
import (
"context"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/broadcast"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)
// RegisterDDLCallbacks registers the ddl callbacks.
func RegisterDDLCallbacks(s *Server) {
ddlCallback := &DDLCallbacks{
Server: s,
}
ddlCallback.registerLoadConfigCallbacks()
ddlCallback.registerResourceGroupCallbacks()
}
@ -30,11 +40,29 @@ type DDLCallbacks struct {
*Server
}
// registerLoadConfigCallbacks registers the load config callbacks.
func (c *DDLCallbacks) registerLoadConfigCallbacks() {
registry.RegisterAlterLoadConfigV2AckCallback(c.alterLoadConfigV2AckCallback)
registry.RegisterDropLoadConfigV2AckCallback(c.dropLoadConfigV2AckCallback)
}
func (c *DDLCallbacks) registerResourceGroupCallbacks() {
registry.RegisterAlterResourceGroupV2AckCallback(c.alterResourceGroupV2AckCallback)
registry.RegisterDropResourceGroupV2AckCallback(c.dropResourceGroupV2AckCallback)
}
func (c *DDLCallbacks) RegisterDDLCallbacks() {
c.registerResourceGroupCallbacks()
// startBroadcastWithCollectionIDLock starts a broadcast with collection id lock.
func (c *Server) startBroadcastWithCollectionIDLock(ctx context.Context, collectionID int64) (broadcaster.BroadcastAPI, error) {
coll, err := c.broker.DescribeCollection(ctx, collectionID)
if err != nil {
return nil, err
}
broadcaster, err := broadcast.StartBroadcastWithResourceKeys(ctx,
message.NewSharedDBNameResourceKey(coll.GetDbName()),
message.NewExclusiveCollectionNameResourceKey(coll.GetDbName(), coll.GetCollectionName()),
)
if err != nil {
return nil, errors.Wrap(err, "failed to start broadcast with collection lock")
}
return broadcaster, nil
}
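
The helper above takes a shared key on the database name plus an exclusive key on the collection name before broadcasting DDL, so DDL on different collections of one database can run concurrently while DDL on the same collection serializes. A rough in-process sketch of that shared/exclusive keying idea (hypothetical keyLock type, not the Milvus broadcaster API):

package main

import (
    "fmt"
    "sync"
)

// keyLock hands out one RWMutex per resource key: shared holders take RLock,
// exclusive holders take Lock, mirroring shared DB keys vs. exclusive collection keys.
type keyLock struct {
    mu    sync.Mutex
    locks map[string]*sync.RWMutex
}

func newKeyLock() *keyLock { return &keyLock{locks: map[string]*sync.RWMutex{}} }

func (k *keyLock) get(key string) *sync.RWMutex {
    k.mu.Lock()
    defer k.mu.Unlock()
    if _, ok := k.locks[key]; !ok {
        k.locks[key] = &sync.RWMutex{}
    }
    return k.locks[key]
}

func main() {
    kl := newKeyLock()

    db := kl.get("db:default")          // shared across all DDL in this database
    coll := kl.get("collection:db1.c1") // exclusive per collection

    db.RLock()
    coll.Lock()
    fmt.Println("DDL on db1.c1 holds db shared + collection exclusive")
    coll.Unlock()
    db.RUnlock()
}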

View File

@ -0,0 +1,37 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querycoordv2
import (
"context"
"github.com/milvus-io/milvus/internal/querycoordv2/job"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)
// alterLoadConfigV2AckCallback is called when the alter load config message is acknowledged
func (s *Server) alterLoadConfigV2AckCallback(ctx context.Context, result message.BroadcastResultAlterLoadConfigMessageV2) error {
// currently, we only send the alter load config message to the control channel
// TODO: after we support query view in 3.0, we should broadcast the alter load config message to all vchannels.
loadJob := job.NewLoadCollectionJob(ctx, result, s.dist, s.meta, s.broker, s.targetMgr, s.targetObserver, s.collectionObserver, s.nodeMgr)
if err := loadJob.Execute(); err != nil {
return err
}
meta.GlobalFailedLoadCache.Remove(result.Message.Header().GetCollectionId())
return nil
}
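
The callback above only runs after the AlterLoadConfig broadcast has been acknowledged; QueryCoord then materializes it as an ordinary load job. A stripped-down sketch of that register-then-run-on-ack flow (hypothetical registry types, not the Milvus registry package):

package main

import (
    "context"
    "fmt"
)

// AckResult is a stand-in for the broadcast result delivered to callbacks.
type AckResult struct{ CollectionID int64 }

// ackCallbacks maps a message type name to the callback run after the ack.
var ackCallbacks = map[string]func(context.Context, AckResult) error{}

func registerAckCallback(msgType string, cb func(context.Context, AckResult) error) {
    ackCallbacks[msgType] = cb
}

// onBroadcastAcked is what a broadcaster would call once all channels have acked.
func onBroadcastAcked(ctx context.Context, msgType string, r AckResult) error {
    if cb, ok := ackCallbacks[msgType]; ok {
        return cb(ctx, r)
    }
    return nil
}

func main() {
    registerAckCallback("AlterLoadConfig", func(ctx context.Context, r AckResult) error {
        fmt.Println("executing load job for collection", r.CollectionID)
        return nil
    })
    _ = onBroadcastAcked(context.Background(), "AlterLoadConfig", AckResult{CollectionID: 100})
}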

View File

@ -0,0 +1,126 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querycoordv2
import (
"context"
"fmt"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/querycoordv2/job"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
)
// broadcastAlterLoadConfigCollectionV2ForLoadCollection is called when the load collection request is received.
func (s *Server) broadcastAlterLoadConfigCollectionV2ForLoadCollection(ctx context.Context, req *querypb.LoadCollectionRequest) error {
broadcaster, err := s.startBroadcastWithCollectionIDLock(ctx, req.GetCollectionID())
if err != nil {
return err
}
defer broadcaster.Close()
// double check if the collection is already dropped
coll, err := s.broker.DescribeCollection(ctx, req.GetCollectionID())
if err != nil {
return err
}
partitionIDs, err := s.broker.GetPartitions(ctx, coll.CollectionID)
if err != nil {
return err
}
// if the user specified the replica number in the load request, load config changes won't be applied to the collection automatically
userSpecifiedReplicaMode := req.GetReplicaNumber() > 0
replicaNumber, resourceGroups, err := s.getDefaultResourceGroupsAndReplicaNumber(ctx, req.GetReplicaNumber(), req.GetResourceGroups(), req.GetCollectionID())
if err != nil {
return err
}
currentLoadConfig := s.getCurrentLoadConfig(ctx, req.GetCollectionID())
// only check node number when the collection is not loaded
expectedReplicasNumber, err := utils.AssignReplica(ctx, s.meta, resourceGroups, replicaNumber, currentLoadConfig.Collection == nil)
if err != nil {
return err
}
alterLoadConfigReq := &job.AlterLoadConfigRequest{
Meta: s.meta,
CollectionInfo: coll,
Current: currentLoadConfig,
Expected: job.ExpectedLoadConfig{
ExpectedPartitionIDs: partitionIDs,
ExpectedReplicaNumber: expectedReplicasNumber,
ExpectedFieldIndexID: req.GetFieldIndexID(),
ExpectedLoadFields: req.GetLoadFields(),
ExpectedPriority: req.GetPriority(),
ExpectedUserSpecifiedReplicaMode: userSpecifiedReplicaMode,
},
}
msg, err := job.GenerateAlterLoadConfigMessage(ctx, alterLoadConfigReq)
if err != nil {
return err
}
_, err = broadcaster.Broadcast(ctx, msg)
return err
}
// getDefaultResourceGroupsAndReplicaNumber gets the default resource groups and replica number for the collection.
func (s *Server) getDefaultResourceGroupsAndReplicaNumber(ctx context.Context, replicaNumber int32, resourceGroups []string, collectionID int64) (int32, []string, error) {
// only when both the replica number and the resource groups are unset in the request do we fall back to the configured load info
if replicaNumber <= 0 && len(resourceGroups) == 0 {
// neither replica number nor resource groups is set, use the pre-defined load config
rgs, replicas, err := s.broker.GetCollectionLoadInfo(ctx, collectionID)
if err != nil {
log.Warn("failed to get pre-defined load info", zap.Error(err))
} else {
replicaNumber = int32(replicas)
resourceGroups = rgs
}
}
// to be compatible with the old SDK, which sets replica=1 if the replica number is not specified
if replicaNumber <= 0 {
log.Info("request doesn't indicate the number of replicas, set it to 1")
replicaNumber = 1
}
if len(resourceGroups) == 0 {
log.Info(fmt.Sprintf("request doesn't indicate the resource groups, set it to %s", meta.DefaultResourceGroupName))
resourceGroups = []string{meta.DefaultResourceGroupName}
}
return replicaNumber, resourceGroups, nil
}
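
As a quick sanity check of the defaulting order above (pre-defined load info only when both fields are empty, then replica=1 and the default resource group), here is a standalone sketch with the pre-defined values passed in as plain parameters (hypothetical names, not the broker API):

package main

import "fmt"

const defaultResourceGroup = "__default_resource_group"

// defaultLoadConfig mirrors the fallback order above: the pre-defined config is
// only consulted when neither replicas nor resource groups were requested.
func defaultLoadConfig(replicas int32, rgs []string, predefReplicas int32, predefRGs []string) (int32, []string) {
    if replicas <= 0 && len(rgs) == 0 {
        replicas, rgs = predefReplicas, predefRGs
    }
    if replicas <= 0 {
        replicas = 1 // old SDKs omit the replica number
    }
    if len(rgs) == 0 {
        rgs = []string{defaultResourceGroup}
    }
    return replicas, rgs
}

func main() {
    r, g := defaultLoadConfig(0, nil, 0, nil)
    fmt.Println(r, g) // 1 [__default_resource_group]

    r, g = defaultLoadConfig(2, []string{"rg1"}, 3, []string{"rgX"})
    fmt.Println(r, g) // 2 [rg1]; the explicit request wins over the pre-defined config
}
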
func (s *Server) getCurrentLoadConfig(ctx context.Context, collectionID int64) job.CurrentLoadConfig {
partitionList := s.meta.CollectionManager.GetPartitionsByCollection(ctx, collectionID)
loadedPartitions := make(map[int64]*meta.Partition)
for _, partition := range partitionList {
loadedPartitions[partition.PartitionID] = partition
}
replicas := s.meta.ReplicaManager.GetByCollection(ctx, collectionID)
loadedReplicas := make(map[int64]*meta.Replica)
for _, replica := range replicas {
loadedReplicas[replica.GetID()] = replica
}
return job.CurrentLoadConfig{
Collection: s.meta.CollectionManager.GetCollection(ctx, collectionID),
Partitions: loadedPartitions,
Replicas: loadedReplicas,
}
}

View File

@ -0,0 +1,80 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querycoordv2
import (
"context"
"github.com/milvus-io/milvus/internal/querycoordv2/job"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func (s *Server) broadcastAlterLoadConfigCollectionV2ForLoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) error {
broadcaster, err := s.startBroadcastWithCollectionIDLock(ctx, req.GetCollectionID())
if err != nil {
return err
}
defer broadcaster.Close()
// double check if the collection is already dropped
coll, err := s.broker.DescribeCollection(ctx, req.GetCollectionID())
if err != nil {
return err
}
userSpecifiedReplicaMode := req.GetReplicaNumber() > 0
replicaNumber, resourceGroups, err := s.getDefaultResourceGroupsAndReplicaNumber(ctx, req.GetReplicaNumber(), req.GetResourceGroups(), req.GetCollectionID())
if err != nil {
return err
}
expectedReplicasNumber, err := utils.AssignReplica(ctx, s.meta, resourceGroups, replicaNumber, true)
if err != nil {
return err
}
currentLoadConfig := s.getCurrentLoadConfig(ctx, req.GetCollectionID())
partitionIDsSet := typeutil.NewSet(currentLoadConfig.GetPartitionIDs()...)
// add new incoming partitionIDs.
for _, partition := range req.PartitionIDs {
partitionIDsSet.Insert(partition)
}
alterLoadConfigReq := &job.AlterLoadConfigRequest{
Meta: s.meta,
CollectionInfo: coll,
Current: currentLoadConfig,
Expected: job.ExpectedLoadConfig{
ExpectedPartitionIDs: partitionIDsSet.Collect(),
ExpectedReplicaNumber: expectedReplicasNumber,
ExpectedFieldIndexID: req.GetFieldIndexID(),
ExpectedLoadFields: req.GetLoadFields(),
ExpectedPriority: req.GetPriority(),
ExpectedUserSpecifiedReplicaMode: userSpecifiedReplicaMode,
},
}
if err := alterLoadConfigReq.CheckIfLoadPartitionsExecutable(); err != nil {
return err
}
msg, err := job.GenerateAlterLoadConfigMessage(ctx, alterLoadConfigReq)
if err != nil {
return err
}
_, err = broadcaster.Broadcast(ctx, msg)
return err
}

View File

@ -0,0 +1,95 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querycoordv2
import (
"context"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/querycoordv2/job"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func (s *Server) broadcastAlterLoadConfigCollectionV2ForReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (collectionReleased bool, err error) {
broadcaster, err := s.startBroadcastWithCollectionIDLock(ctx, req.GetCollectionID())
if err != nil {
return false, err
}
defer broadcaster.Close()
// double check if the collection is already dropped
coll, err := s.broker.DescribeCollection(ctx, req.GetCollectionID())
if err != nil {
return false, err
}
currentLoadConfig := s.getCurrentLoadConfig(ctx, req.GetCollectionID())
if currentLoadConfig.Collection == nil {
// collection is not loaded, return success directly.
return true, nil
}
// remove the partitions that should be released.
partitionIDsSet := typeutil.NewSet(currentLoadConfig.GetPartitionIDs()...)
previousLength := len(partitionIDsSet)
for _, partitionID := range req.PartitionIDs {
partitionIDsSet.Remove(partitionID)
}
// none of the requested partitions are loaded, ignore this release request.
if len(partitionIDsSet) == previousLength {
return false, job.ErrIgnoredAlterLoadConfig
}
var msg message.BroadcastMutableMessage
if len(partitionIDsSet) == 0 {
// all partitions are released, release the collection directly.
msg = message.NewDropLoadConfigMessageBuilderV2().
WithHeader(&message.DropLoadConfigMessageHeader{
DbId: coll.DbId,
CollectionId: coll.CollectionID,
}).
WithBody(&message.DropLoadConfigMessageBody{}).
WithBroadcast([]string{streaming.WAL().ControlChannel()}). // TODO: after we support query view in 3.0, we should broadcast the drop load config message to all vchannels.
MustBuildBroadcast()
collectionReleased = true
} else {
// only some partitions are released, alter the load config.
alterLoadConfigReq := &job.AlterLoadConfigRequest{
Meta: s.meta,
CollectionInfo: coll,
Current: currentLoadConfig,
Expected: job.ExpectedLoadConfig{
ExpectedPartitionIDs: partitionIDsSet.Collect(),
ExpectedReplicaNumber: currentLoadConfig.GetReplicaNumber(),
ExpectedFieldIndexID: currentLoadConfig.GetFieldIndexID(),
ExpectedLoadFields: currentLoadConfig.GetLoadFields(),
ExpectedPriority: commonpb.LoadPriority_HIGH,
ExpectedUserSpecifiedReplicaMode: currentLoadConfig.GetUserSpecifiedReplicaMode(),
},
}
if msg, err = job.GenerateAlterLoadConfigMessage(ctx, alterLoadConfigReq); err != nil {
return false, err
}
collectionReleased = false
}
_, err = broadcaster.Broadcast(ctx, msg)
return collectionReleased, err
}
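
The branching above boils down to set arithmetic on the loaded partition IDs: nothing removed means the request is ignored, an empty remainder drops the whole load config, anything else alters it. A standalone sketch of just that decision (hypothetical outcome type, not the Milvus message builders):

package main

import "fmt"

type releaseOutcome int

const (
    releaseIgnored    releaseOutcome = iota // nothing requested was loaded
    releaseCollection                       // no partitions left: drop the load config
    releasePartitions                       // some partitions remain: alter the load config
)

// planRelease mirrors the decision above on a plain map-backed set.
func planRelease(loaded, toRelease []int64) (releaseOutcome, []int64) {
    remaining := map[int64]struct{}{}
    for _, p := range loaded {
        remaining[p] = struct{}{}
    }
    before := len(remaining)
    for _, p := range toRelease {
        delete(remaining, p)
    }
    switch {
    case len(remaining) == before:
        return releaseIgnored, loaded
    case len(remaining) == 0:
        return releaseCollection, nil
    default:
        rest := make([]int64, 0, len(remaining))
        for p := range remaining {
            rest = append(rest, p)
        }
        return releasePartitions, rest
    }
}

func main() {
    outcome, rest := planRelease([]int64{1, 2, 3}, []int64{2})
    fmt.Println(outcome, rest) // 2 (releasePartitions) plus the remaining IDs, map order may vary

    outcome, _ = planRelease([]int64{1, 2}, []int64{1, 2})
    fmt.Println(outcome) // 1 (releaseCollection)
}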

View File

@ -0,0 +1,79 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querycoordv2
import (
"context"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/querycoordv2/job"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)
var errReleaseCollectionNotLoaded = errors.New("release collection not loaded")
// broadcastDropLoadConfigCollectionV2ForReleaseCollection broadcasts the drop load config message for release collection.
func (s *Server) broadcastDropLoadConfigCollectionV2ForReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) error {
broadcaster, err := s.startBroadcastWithCollectionIDLock(ctx, req.GetCollectionID())
if err != nil {
return err
}
defer broadcaster.Close()
// double check if the collection is already dropped.
coll, err := s.broker.DescribeCollection(ctx, req.GetCollectionID())
if err != nil {
return err
}
if !s.meta.CollectionManager.Exist(ctx, req.GetCollectionID()) {
return errReleaseCollectionNotLoaded
}
msg := message.NewDropLoadConfigMessageBuilderV2().
WithHeader(&message.DropLoadConfigMessageHeader{
DbId: coll.GetDbId(),
CollectionId: coll.GetCollectionID(),
}).
WithBody(&message.DropLoadConfigMessageBody{}).
WithBroadcast([]string{streaming.WAL().ControlChannel()}). // TODO: after we support query view in 3.0, we should broadcast the drop load config message to all vchannels.
MustBuildBroadcast()
_, err = broadcaster.Broadcast(ctx, msg)
return err
}
func (s *Server) dropLoadConfigV2AckCallback(ctx context.Context, result message.BroadcastResultDropLoadConfigMessageV2) error {
releaseJob := job.NewReleaseCollectionJob(ctx,
result,
s.dist,
s.meta,
s.broker,
s.targetMgr,
s.targetObserver,
s.checkerController,
s.proxyClientManager,
)
if err := releaseJob.Execute(); err != nil {
return err
}
meta.GlobalFailedLoadCache.Remove(result.Message.Header().GetCollectionId())
return nil
}

View File

@ -0,0 +1,917 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querycoordv2
import (
"context"
"time"
"github.com/milvus-io/milvus-proto/go-api/v2/rgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/querycoordv2/job"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
func (suite *ServiceSuite) TestDDLCallbacksLoadCollectionInfo() {
ctx := context.Background()
suite.expectGetRecoverInfoForAllCollections()
// Test load collection
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadCollection {
continue
}
// Load with 1 replica
req := &querypb.LoadCollectionRequest{
CollectionID: collection,
// It will be set to 1
// ReplicaNumber: 1,
}
resp, err := suite.server.LoadCollection(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.NoError(err)
suite.EqualValues(1, suite.meta.GetReplicaNumber(ctx, collection))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertCollectionLoaded(collection)
}
// Test load again
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadCollection {
continue
}
req := &querypb.LoadCollectionRequest{
CollectionID: collection,
}
resp, err := suite.server.LoadCollection(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
}
// Test load partition while collection exists
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadCollection {
continue
}
// Load with 1 replica
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: suite.partitions[collection],
ReplicaNumber: 1,
}
resp, err := suite.server.LoadPartitions(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
}
// Test load existed collection with different replica number
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadCollection {
continue
}
req := &querypb.LoadCollectionRequest{
CollectionID: collection,
ReplicaNumber: 3,
}
resp, err := suite.server.LoadCollection(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
}
cfg := &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{
NodeNum: 0,
},
Limits: &rgpb.ResourceGroupLimit{
NodeNum: 0,
},
}
suite.meta.ResourceManager.AddResourceGroup(ctx, "rg1", cfg)
suite.meta.ResourceManager.AddResourceGroup(ctx, "rg2", cfg)
suite.meta.ResourceManager.AddResourceGroup(ctx, "rg3", cfg)
// Load with 3 replica on 1 rg
req := &querypb.LoadCollectionRequest{
CollectionID: 1001,
ReplicaNumber: 3,
ResourceGroups: []string{"rg1"},
}
resp, err := suite.server.LoadCollection(ctx, req)
suite.Require().ErrorIs(merr.CheckRPCCall(resp, err), merr.ErrResourceGroupNodeNotEnough)
// Load with 3 replica on 3 rg
req = &querypb.LoadCollectionRequest{
CollectionID: 1001,
ReplicaNumber: 3,
ResourceGroups: []string{"rg1", "rg2", "rg3"},
}
resp, err = suite.server.LoadCollection(ctx, req)
suite.Require().ErrorIs(merr.CheckRPCCall(resp, err), merr.ErrResourceGroupNodeNotEnough)
}
func (suite *ServiceSuite) TestDDLCallbacksLoadCollectionWithReplicas() {
ctx := context.Background()
suite.expectGetRecoverInfoForAllCollections()
// Test load collection
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadCollection {
continue
}
// Load with more replicas than there are nodes
req := &querypb.LoadCollectionRequest{
CollectionID: collection,
ReplicaNumber: int32(len(suite.nodes) + 1),
}
resp, err := suite.server.LoadCollection(ctx, req)
suite.Require().ErrorIs(merr.CheckRPCCall(resp, err), merr.ErrResourceGroupNodeNotEnough)
}
}
func (suite *ServiceSuite) TestDDLCallbacksLoadCollectionWithLoadFields() {
ctx := context.Background()
suite.expectGetRecoverInfoForAllCollections()
suite.Run("init_load", func() {
// Test load collection
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadCollection {
continue
}
// Load with 1 replica
req := &querypb.LoadCollectionRequest{
CollectionID: collection,
LoadFields: []int64{100, 101, 102},
}
resp, err := suite.server.LoadCollection(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.EqualValues(1, suite.meta.GetReplicaNumber(ctx, collection))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertCollectionLoaded(collection)
}
})
suite.Run("load_again_same_fields", func() {
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadCollection {
continue
}
req := &querypb.LoadCollectionRequest{
CollectionID: collection,
LoadFields: []int64{102, 101, 100}, // field id order shall not matter
}
resp, err := suite.server.LoadCollection(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
}
})
suite.Run("load_again_diff_fields", func() {
// Test load existed collection with different load fields
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadCollection {
continue
}
req := &querypb.LoadCollectionRequest{
CollectionID: collection,
LoadFields: []int64{100, 101},
}
resp, err := suite.server.LoadCollection(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
}
})
suite.Run("load_from_legacy_proxy", func() {
// Test load again with legacy proxy
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadCollection {
continue
}
req := &querypb.LoadCollectionRequest{
CollectionID: collection,
Schema: &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{FieldID: 100},
{FieldID: 101},
{FieldID: 102},
},
},
}
resp, err := suite.server.LoadCollection(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
}
})
}
func (suite *ServiceSuite) TestDDLCallbacksLoadPartition() {
ctx := context.Background()
suite.expectGetRecoverInfoForAllCollections()
// Test load partition
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadPartition {
continue
}
// Load with 1 replica
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: suite.partitions[collection],
ReplicaNumber: 1,
}
resp, err := suite.server.LoadPartitions(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.EqualValues(1, suite.meta.GetReplicaNumber(ctx, collection))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertCollectionLoaded(collection)
}
// Test load partition again
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadPartition {
continue
}
// Load with 1 replica
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: suite.partitions[collection],
// ReplicaNumber: 1,
}
resp, err := suite.server.LoadPartitions(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
}
// Test load partition with different replica number
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadPartition {
continue
}
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: suite.partitions[collection],
ReplicaNumber: 3,
}
resp, err := suite.server.LoadPartitions(ctx, req)
suite.Require().ErrorIs(merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid)
}
// Test load partition with more partition
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadPartition {
continue
}
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: append(suite.partitions[collection], 200),
ReplicaNumber: 1,
}
resp, err := suite.server.LoadPartitions(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
}
// Test load collection while partitions exists
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadPartition {
continue
}
req := &querypb.LoadCollectionRequest{
CollectionID: collection,
ReplicaNumber: 1,
}
resp, err := suite.server.LoadCollection(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
}
cfg := &rgpb.ResourceGroupConfig{
Requests: &rgpb.ResourceGroupLimit{
NodeNum: 1,
},
Limits: &rgpb.ResourceGroupLimit{
NodeNum: 1,
},
}
suite.meta.ResourceManager.AddResourceGroup(ctx, "rg1", cfg)
suite.meta.ResourceManager.AddResourceGroup(ctx, "rg2", cfg)
suite.meta.ResourceManager.AddResourceGroup(ctx, "rg3", cfg)
// test loading 3 replicas in 1 rg: should fail, not enough nodes in the resource group
req := &querypb.LoadPartitionsRequest{
CollectionID: 999,
PartitionIDs: []int64{888},
ReplicaNumber: 3,
ResourceGroups: []string{"rg1"},
}
resp, err := suite.server.LoadPartitions(ctx, req)
suite.Require().ErrorIs(merr.CheckRPCCall(resp, err), merr.ErrResourceGroupNodeNotEnough)
// test loading 3 replicas in 3 rgs: should still fail, the resource groups don't have enough nodes
req = &querypb.LoadPartitionsRequest{
CollectionID: 999,
PartitionIDs: []int64{888},
ReplicaNumber: 3,
ResourceGroups: []string{"rg1", "rg2", "rg3"},
}
resp, err = suite.server.LoadPartitions(ctx, req)
suite.Require().ErrorIs(merr.CheckRPCCall(resp, err), merr.ErrResourceGroupNodeNotEnough)
}
func (suite *ServiceSuite) TestLoadPartitionWithLoadFields() {
ctx := context.Background()
suite.expectGetRecoverInfoForAllCollections()
suite.Run("init_load", func() {
// Test load partition
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadPartition {
continue
}
// Load with 1 replica
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: suite.partitions[collection],
ReplicaNumber: 1,
LoadFields: []int64{100, 101, 102},
}
resp, err := suite.server.LoadPartitions(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.EqualValues(1, suite.meta.GetReplicaNumber(ctx, collection))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertCollectionLoaded(collection)
}
})
suite.Run("load_with_same_load_fields", func() {
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadPartition {
continue
}
// Load with 1 replica
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: suite.partitions[collection],
ReplicaNumber: 1,
LoadFields: []int64{102, 101, 100},
}
resp, err := suite.server.LoadPartitions(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
}
})
suite.Run("load_with_diff_load_fields", func() {
// Test load partition with different load fields
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadPartition {
continue
}
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: suite.partitions[collection],
LoadFields: []int64{100, 101},
}
resp, err := suite.server.LoadPartitions(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
}
})
suite.Run("load_legacy_proxy", func() {
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadPartition {
continue
}
// Load with 1 replica
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: suite.partitions[collection],
ReplicaNumber: 1,
Schema: &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{FieldID: 100},
{FieldID: 101},
{FieldID: 102},
},
},
}
resp, err := suite.server.LoadPartitions(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
}
})
}
func (suite *ServiceSuite) TestDynamicLoad() {
ctx := context.Background()
suite.expectGetRecoverInfoForAllCollections()
collection := suite.collections[0]
p0, p1, p2 := suite.partitions[collection][0], suite.partitions[collection][1], suite.partitions[collection][2]
newLoadPartJob := func(partitions ...int64) *querypb.LoadPartitionsRequest {
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: partitions,
ReplicaNumber: 1,
}
return req
}
newLoadColJob := func() *querypb.LoadCollectionRequest {
return &querypb.LoadCollectionRequest{
CollectionID: collection,
ReplicaNumber: 1,
}
}
// loaded: none
// action: load p0, p1, p2
// expect: p0, p1, p2 loaded
req := newLoadPartJob(p0, p1, p2)
resp, err := suite.server.LoadPartitions(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertPartitionLoaded(ctx, collection, p0, p1, p2)
// loaded: p0, p1, p2
// action: load p0, p1, p2
// expect: do nothing, p0, p1, p2 loaded
req = newLoadPartJob(p0, p1, p2)
resp, err = suite.server.LoadPartitions(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.assertPartitionLoaded(ctx, collection)
// loaded: p0, p1
// action: load p2
// expect: p0, p1, p2 loaded
suite.releaseAll()
req = newLoadPartJob(p0, p1)
resp, err = suite.server.LoadPartitions(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertPartitionLoaded(ctx, collection, p0, p1)
req = newLoadPartJob(p2)
resp, err = suite.server.LoadPartitions(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertPartitionLoaded(ctx, collection, p2)
// loaded: p0, p1
// action: load p1, p2
// expect: p0, p1, p2 loaded
suite.releaseAll()
req = newLoadPartJob(p0, p1)
resp, err = suite.server.LoadPartitions(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertPartitionLoaded(ctx, collection, p0, p1)
req = newLoadPartJob(p1, p2)
resp, err = suite.server.LoadPartitions(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertPartitionLoaded(ctx, collection, p2)
// loaded: p0, p1
// action: load col
// expect: col loaded
suite.releaseAll()
req = newLoadPartJob(p0, p1)
resp, err = suite.server.LoadPartitions(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertPartitionLoaded(ctx, collection, p0, p1)
colJob := newLoadColJob()
resp, err = suite.server.LoadCollection(ctx, colJob)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertPartitionLoaded(ctx, collection, p2)
}
func (suite *ServiceSuite) TestLoadPartitionWithReplicas() {
ctx := context.Background()
suite.expectGetRecoverInfoForAllCollections()
// Test load partitions
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadPartition {
continue
}
// Load with more replicas than available nodes
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: suite.partitions[collection],
ReplicaNumber: 11,
}
resp, err := suite.server.LoadPartitions(ctx, req)
suite.Require().ErrorIs(merr.CheckRPCCall(resp, err), merr.ErrResourceGroupNodeNotEnough)
}
}
func (suite *ServiceSuite) TestDDLCallbacksReleaseCollection() {
ctx := context.Background()
suite.expectGetRecoverInfoForAllCollections()
suite.loadAll()
// Test release collection and partition
for _, collection := range suite.collections {
req := &querypb.ReleaseCollectionRequest{
CollectionID: collection,
}
resp, err := suite.server.ReleaseCollection(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.assertCollectionReleased(collection)
}
// Test release again
for _, collection := range suite.collections {
req := &querypb.ReleaseCollectionRequest{
CollectionID: collection,
}
resp, err := suite.server.ReleaseCollection(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.assertCollectionReleased(collection)
}
}
func (suite *ServiceSuite) TestDDLCallbacksReleasePartition() {
ctx := context.Background()
suite.expectGetRecoverInfoForAllCollections()
suite.loadAll()
// Test release partition
for _, collection := range suite.collections {
req := &querypb.ReleasePartitionsRequest{
CollectionID: collection,
PartitionIDs: suite.partitions[collection],
}
resp, err := suite.server.ReleasePartitions(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.assertPartitionReleased(collection, suite.partitions[collection]...)
}
// Test release again
for _, collection := range suite.collections {
req := &querypb.ReleasePartitionsRequest{
CollectionID: collection,
PartitionIDs: suite.partitions[collection],
}
resp, err := suite.server.ReleasePartitions(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.assertPartitionReleased(collection, suite.partitions[collection]...)
}
// Test release partial partitions
suite.releaseAll()
suite.loadAll()
for _, collectionID := range suite.collections {
// make collection able to get into loaded state
suite.updateChannelDist(ctx, collectionID)
suite.updateSegmentDist(collectionID, 3000, suite.partitions[collectionID]...)
job.WaitCurrentTargetUpdated(ctx, suite.targetObserver, collectionID)
}
for _, collection := range suite.collections {
req := &querypb.ReleasePartitionsRequest{
CollectionID: collection,
PartitionIDs: suite.partitions[collection][1:],
}
ch := make(chan struct{})
go func() {
defer close(ch)
time.Sleep(100 * time.Millisecond)
suite.updateChannelDist(ctx, collection)
suite.updateSegmentDist(collection, 3000, suite.partitions[collection][:1]...)
}()
resp, err := suite.server.ReleasePartitions(ctx, req)
<-ch
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.True(suite.meta.Exist(ctx, collection))
partitions := suite.meta.GetPartitionsByCollection(ctx, collection)
suite.Len(partitions, 1)
suite.Equal(suite.partitions[collection][0], partitions[0].GetPartitionID())
suite.assertPartitionReleased(collection, suite.partitions[collection][1:]...)
}
}
func (suite *ServiceSuite) TestDynamicRelease() {
ctx := context.Background()
suite.expectGetRecoverInfoForAllCollections()
col0, col1 := suite.collections[0], suite.collections[1]
p0, p1, p2 := suite.partitions[col0][0], suite.partitions[col0][1], suite.partitions[col0][2]
p3, p4, p5 := suite.partitions[col1][0], suite.partitions[col1][1], suite.partitions[col1][2]
newReleasePartJob := func(col int64, partitions ...int64) *querypb.ReleasePartitionsRequest {
return &querypb.ReleasePartitionsRequest{
CollectionID: col,
PartitionIDs: partitions,
}
}
newReleaseColJob := func(col int64) *querypb.ReleaseCollectionRequest {
return &querypb.ReleaseCollectionRequest{
CollectionID: col,
}
}
// loaded: p0, p1, p2
// action: release p0
// expect: p0 released, p1, p2 loaded
suite.loadAll()
for _, collectionID := range suite.collections {
// make collection able to get into loaded state
suite.updateChannelDist(ctx, collectionID)
suite.updateSegmentDist(collectionID, 3000, suite.partitions[collectionID]...)
job.WaitCurrentTargetUpdated(ctx, suite.targetObserver, collectionID)
}
req := newReleasePartJob(col0, p0)
// update segments
ch := make(chan struct{})
go func() {
defer close(ch)
time.Sleep(100 * time.Millisecond)
suite.updateSegmentDist(col0, 3000, p1, p2)
suite.updateChannelDist(ctx, col0)
}()
resp, err := suite.server.ReleasePartitions(ctx, req)
<-ch
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.assertPartitionReleased(col0, p0)
suite.assertPartitionLoaded(ctx, col0, p1, p2)
// loaded: p1, p2
// action: release p0, p1
// expect: p1 released, p2 loaded
req = newReleasePartJob(col0, p0, p1)
ch = make(chan struct{})
go func() {
defer close(ch)
time.Sleep(100 * time.Millisecond)
suite.updateSegmentDist(col0, 3000, p2)
suite.updateChannelDist(ctx, col0)
}()
resp, err = suite.server.ReleasePartitions(ctx, req)
<-ch
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.assertPartitionReleased(col0, p0, p1)
suite.assertPartitionLoaded(ctx, col0, p2)
// loaded: p2
// action: release p2
// expect: p2 released; no partitions remain, so the full collection is released.
req = newReleasePartJob(col0, p2)
ch = make(chan struct{})
go func() {
defer close(ch)
time.Sleep(100 * time.Millisecond)
suite.releaseSegmentDist(3000)
suite.releaseAllChannelDist()
}()
resp, err = suite.server.ReleasePartitions(ctx, req)
<-ch
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.assertPartitionReleased(col0, p0, p1, p2)
suite.False(suite.meta.Exist(ctx, col0))
// loaded: p0, p1, p2
// action: release col
// expect: col released
suite.releaseAll()
suite.loadAll()
req2 := newReleaseColJob(col0)
resp, err = suite.server.ReleaseCollection(ctx, req2)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.assertCollectionReleased(col0)
suite.assertPartitionReleased(col0, p0, p1, p2)
// loaded: p3, p4, p5
// action: release p3, p4, p5
// expect: loadType=partition: col released
suite.releaseAll()
suite.loadAll()
req = newReleasePartJob(col1, p3, p4, p5)
resp, err = suite.server.ReleasePartitions(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.assertCollectionReleased(col1)
suite.assertPartitionReleased(col1, p3, p4, p5)
}
func (suite *ServiceSuite) releaseAll() {
ctx := context.Background()
for _, collection := range suite.collections {
resp, err := suite.server.ReleaseCollection(ctx, &querypb.ReleaseCollectionRequest{
CollectionID: collection,
})
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.assertCollectionReleased(collection)
}
}
func (suite *ServiceSuite) assertCollectionReleased(collection int64) {
ctx := context.Background()
suite.False(suite.meta.Exist(ctx, collection))
suite.Equal(0, len(suite.meta.ReplicaManager.GetByCollection(ctx, collection)))
for _, channel := range suite.channels[collection] {
suite.Nil(suite.targetMgr.GetDmChannel(ctx, collection, channel, meta.CurrentTarget))
}
for _, partitions := range suite.segments[collection] {
for _, segment := range partitions {
suite.Nil(suite.targetMgr.GetSealedSegment(ctx, collection, segment, meta.CurrentTarget))
}
}
}
func (suite *ServiceSuite) assertPartitionReleased(collection int64, partitionIDs ...int64) {
ctx := context.Background()
for _, partition := range partitionIDs {
suite.Nil(suite.meta.GetPartition(ctx, partition))
segments := suite.segments[collection][partition]
for _, segment := range segments {
suite.Nil(suite.targetMgr.GetSealedSegment(ctx, collection, segment, meta.CurrentTarget))
}
}
}
func (suite *ServiceSuite) TestDDLCallbacksLoadCollectionWithUserSpecifiedReplicaMode() {
ctx := context.Background()
suite.expectGetRecoverInfoForAllCollections()
// Test load collection with userSpecifiedReplicaMode = true
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadCollection {
continue
}
req := &querypb.LoadCollectionRequest{
CollectionID: collection,
ReplicaNumber: 1,
}
resp, err := suite.server.LoadCollection(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
// Verify UserSpecifiedReplicaMode is set correctly
loadedCollection := suite.meta.GetCollection(ctx, collection)
suite.NotNil(loadedCollection)
suite.True(loadedCollection.GetUserSpecifiedReplicaMode())
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertCollectionLoaded(collection)
}
}
func (suite *ServiceSuite) TestLoadPartitionWithUserSpecifiedReplicaMode() {
ctx := context.Background()
suite.expectGetRecoverInfoForAllCollections()
// Test load partition with userSpecifiedReplicaMode = true
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadPartition {
continue
}
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: suite.partitions[collection],
ReplicaNumber: 1,
}
resp, err := suite.server.LoadPartitions(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
// Verify UserSpecifiedReplicaMode is set correctly
loadedCollection := suite.meta.GetCollection(ctx, collection)
suite.NotNil(loadedCollection)
suite.True(loadedCollection.GetUserSpecifiedReplicaMode())
suite.targetMgr.UpdateCollectionCurrentTarget(ctx, collection)
suite.assertCollectionLoaded(collection)
}
}
func (suite *ServiceSuite) TestLoadPartitionUpdateUserSpecifiedReplicaMode() {
ctx := context.Background()
suite.expectGetRecoverInfoForAllCollections()
// First load partition with userSpecifiedReplicaMode = false
collection := suite.collections[1] // Use partition load type collection
if suite.loadTypes[collection] != querypb.LoadType_LoadPartition {
return
}
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: suite.partitions[collection][:1], // Load first partition
}
resp, err := suite.server.LoadPartitions(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
// Verify UserSpecifiedReplicaMode is false
loadedCollection := suite.meta.GetCollection(ctx, collection)
suite.NotNil(loadedCollection)
suite.False(loadedCollection.GetUserSpecifiedReplicaMode())
// Load another partition with userSpecifiedReplicaMode = true
req2 := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: suite.partitions[collection][1:2], // Load second partition
ReplicaNumber: 1,
}
resp, err = suite.server.LoadPartitions(ctx, req2)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
// Verify UserSpecifiedReplicaMode is updated to true
updatedCollection := suite.meta.GetCollection(ctx, collection)
suite.NotNil(updatedCollection)
suite.True(updatedCollection.GetUserSpecifiedReplicaMode())
}
func (suite *ServiceSuite) TestSyncNewCreatedPartition() {
newPartition := int64(999)
ctx := context.Background()
// test sync new created partition
suite.loadAll()
collectionID := suite.collections[0]
// make collection able to get into loaded state
suite.updateChannelDist(ctx, collectionID)
suite.updateSegmentDist(collectionID, 3000, suite.partitions[collectionID]...)
req := &querypb.SyncNewCreatedPartitionRequest{
CollectionID: collectionID,
PartitionID: newPartition,
}
syncJob := job.NewSyncNewCreatedPartitionJob(
ctx,
req,
suite.meta,
suite.broker,
suite.targetObserver,
suite.targetMgr,
)
suite.jobScheduler.Add(syncJob)
err := syncJob.Wait()
suite.NoError(err)
partition := suite.meta.CollectionManager.GetPartition(ctx, newPartition)
suite.NotNil(partition)
suite.Equal(querypb.LoadStatus_Loaded, partition.GetStatus())
// test collection not loaded
req = &querypb.SyncNewCreatedPartitionRequest{
CollectionID: int64(888),
PartitionID: newPartition,
}
syncJob = job.NewSyncNewCreatedPartitionJob(
ctx,
req,
suite.meta,
suite.broker,
suite.targetObserver,
suite.targetMgr,
)
suite.jobScheduler.Add(syncJob)
err = syncJob.Wait()
suite.NoError(err)
// test collection loaded, but its loadType is loadPartition
req = &querypb.SyncNewCreatedPartitionRequest{
CollectionID: suite.collections[1],
PartitionID: newPartition,
}
syncJob = job.NewSyncNewCreatedPartitionJob(
ctx,
req,
suite.meta,
suite.broker,
suite.targetObserver,
suite.targetMgr,
)
suite.jobScheduler.Add(syncJob)
err = syncJob.Wait()
suite.NoError(err)
}
func (suite *ServiceSuite) assertCollectionLoaded(collection int64) {
ctx := context.Background()
suite.True(suite.meta.Exist(ctx, collection))
suite.NotEqual(0, len(suite.meta.ReplicaManager.GetByCollection(ctx, collection)))
for _, channel := range suite.channels[collection] {
suite.NotNil(suite.targetMgr.GetDmChannel(ctx, collection, channel, meta.CurrentTarget))
}
for _, segments := range suite.segments[collection] {
for _, segment := range segments {
suite.NotNil(suite.targetMgr.GetSealedSegment(ctx, collection, segment, meta.CurrentTarget))
}
}
}

View File

@ -27,7 +27,6 @@ import (
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/observers"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
@ -36,29 +35,27 @@ import (
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
type LoadCollectionJob struct {
*BaseJob
req *querypb.LoadCollectionRequest
undo *UndoList
dist *meta.DistributionManager
meta *meta.Meta
broker meta.Broker
targetMgr meta.TargetManagerInterface
targetObserver *observers.TargetObserver
collectionObserver *observers.CollectionObserver
nodeMgr *session.NodeManager
collInfo *milvuspb.DescribeCollectionResponse
userSpecifiedReplicaMode bool
result message.BroadcastResultAlterLoadConfigMessageV2
undo *UndoList
dist *meta.DistributionManager
meta *meta.Meta
broker meta.Broker
targetMgr meta.TargetManagerInterface
targetObserver *observers.TargetObserver
collectionObserver *observers.CollectionObserver
nodeMgr *session.NodeManager
}
func NewLoadCollectionJob(
ctx context.Context,
req *querypb.LoadCollectionRequest,
result message.BroadcastResultAlterLoadConfigMessageV2,
dist *meta.DistributionManager,
meta *meta.Meta,
broker meta.Broker,
@ -66,130 +63,60 @@ func NewLoadCollectionJob(
targetObserver *observers.TargetObserver,
collectionObserver *observers.CollectionObserver,
nodeMgr *session.NodeManager,
userSpecifiedReplicaMode bool,
) *LoadCollectionJob {
return &LoadCollectionJob{
BaseJob: NewBaseJob(ctx, req.Base.GetMsgID(), req.GetCollectionID()),
req: req,
undo: NewUndoList(ctx, meta, targetMgr, targetObserver),
dist: dist,
meta: meta,
broker: broker,
targetMgr: targetMgr,
targetObserver: targetObserver,
collectionObserver: collectionObserver,
nodeMgr: nodeMgr,
userSpecifiedReplicaMode: userSpecifiedReplicaMode,
BaseJob: NewBaseJob(ctx, 0, result.Message.Header().GetCollectionId()),
result: result,
undo: NewUndoList(ctx, meta, targetMgr, targetObserver),
dist: dist,
meta: meta,
broker: broker,
targetMgr: targetMgr,
targetObserver: targetObserver,
collectionObserver: collectionObserver,
nodeMgr: nodeMgr,
}
}
func (job *LoadCollectionJob) PreExecute() error {
req := job.req
log := log.Ctx(job.ctx).With(zap.Int64("collectionID", req.GetCollectionID()))
if req.GetReplicaNumber() <= 0 {
log.Info("request doesn't indicate the number of replicas, set it to 1",
zap.Int32("replicaNumber", req.GetReplicaNumber()))
req.ReplicaNumber = 1
}
if len(req.GetResourceGroups()) == 0 {
req.ResourceGroups = []string{meta.DefaultResourceGroupName}
}
var err error
job.collInfo, err = job.broker.DescribeCollection(job.ctx, req.GetCollectionID())
if err != nil {
log.Warn("failed to describe collection from RootCoord", zap.Error(err))
return err
}
collection := job.meta.GetCollection(job.ctx, req.GetCollectionID())
if collection == nil {
return nil
}
if collection.GetReplicaNumber() != req.GetReplicaNumber() {
msg := fmt.Sprintf("collection with different replica number %d existed, release this collection first before changing its replica number",
job.meta.GetReplicaNumber(job.ctx, req.GetCollectionID()),
)
log.Warn(msg)
return merr.WrapErrParameterInvalid(collection.GetReplicaNumber(), req.GetReplicaNumber(), "can't change the replica number for loaded collection")
}
collectionUsedRG := job.meta.ReplicaManager.GetResourceGroupByCollection(job.ctx, collection.GetCollectionID()).Collect()
left, right := lo.Difference(collectionUsedRG, req.GetResourceGroups())
if len(left) > 0 || len(right) > 0 {
msg := fmt.Sprintf("collection with different resource groups %v existed, release this collection first before changing its resource groups",
collectionUsedRG)
log.Warn(msg)
return merr.WrapErrParameterInvalid(collectionUsedRG, req.GetResourceGroups(), "can't change the resource groups for loaded partitions")
}
return nil
}
func (job *LoadCollectionJob) Execute() error {
req := job.req
log := log.Ctx(job.ctx).With(zap.Int64("collectionID", req.GetCollectionID()))
meta.GlobalFailedLoadCache.Remove(req.GetCollectionID())
req := job.result.Message.Header()
vchannels := job.result.GetVChannelsWithoutControlChannel()
// 1. Fetch target partitions
partitionIDs, err := job.broker.GetPartitions(job.ctx, req.GetCollectionID())
log := log.Ctx(job.ctx).With(zap.Int64("collectionID", req.GetCollectionId()))
meta.GlobalFailedLoadCache.Remove(req.GetCollectionId())
// 1. create replica if not exist
if _, err := utils.SpawnReplicasWithReplicaConfig(job.ctx, job.meta, meta.SpawnWithReplicaConfigParams{
CollectionID: req.GetCollectionId(),
Channels: vchannels,
Configs: req.GetReplicas(),
}); err != nil {
return err
}
collInfo, err := job.broker.DescribeCollection(job.ctx, req.GetCollectionId())
if err != nil {
msg := "failed to get partitions from RootCoord"
log.Warn(msg, zap.Error(err))
return errors.Wrap(err, msg)
return err
}
loadedPartitionIDs := lo.Map(job.meta.CollectionManager.GetPartitionsByCollection(job.ctx, req.GetCollectionID()),
func(partition *meta.Partition, _ int) int64 {
return partition.GetPartitionID()
})
lackPartitionIDs := lo.FilterMap(partitionIDs, func(partID int64, _ int) (int64, bool) {
return partID, !lo.Contains(loadedPartitionIDs, partID)
})
if len(lackPartitionIDs) == 0 {
return nil
}
job.undo.CollectionID = req.GetCollectionID()
job.undo.LackPartitions = lackPartitionIDs
log.Info("find partitions to load", zap.Int64s("partitions", lackPartitionIDs))
colExisted := job.meta.CollectionManager.Exist(job.ctx, req.GetCollectionID())
if !colExisted {
// Clear stale replicas, https://github.com/milvus-io/milvus/issues/20444
err = job.meta.ReplicaManager.RemoveCollection(job.ctx, req.GetCollectionID())
if err != nil {
msg := "failed to clear stale replicas"
log.Warn(msg, zap.Error(err))
return errors.Wrap(err, msg)
// 2. put load info meta
fieldIndexIDs := make(map[int64]int64, len(req.GetLoadFields()))
fieldIDs := make([]int64, len(req.GetLoadFields()))
for _, loadField := range req.GetLoadFields() {
if loadField.GetIndexId() != 0 {
fieldIndexIDs[loadField.GetFieldId()] = loadField.GetIndexId()
}
fieldIDs = append(fieldIDs, loadField.GetFieldId())
}
// 2. create replica if not exist
replicas := job.meta.ReplicaManager.GetByCollection(job.ctx, req.GetCollectionID())
if len(replicas) == 0 {
// API of LoadCollection is weird; we should use map[resourceGroupNames]replicaNumber as input, to keep consistency with the `TransferReplica` API.
// Then we can implement dynamic replica changes in different resource groups independently.
_, err = utils.SpawnReplicasWithRG(job.ctx, job.meta, req.GetCollectionID(), req.GetResourceGroups(),
req.GetReplicaNumber(), job.collInfo.GetVirtualChannelNames(), req.GetPriority())
if err != nil {
msg := "failed to spawn replica for collection"
log.Warn(msg, zap.Error(err))
return errors.Wrap(err, msg)
}
job.undo.IsReplicaCreated = true
}
// 4. put collection/partitions meta
partitions := lo.Map(lackPartitionIDs, func(partID int64, _ int) *meta.Partition {
replicaNumber := int32(len(req.GetReplicas()))
partitions := lo.Map(req.GetPartitionIds(), func(partID int64, _ int) *meta.Partition {
return &meta.Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: req.GetCollectionID(),
CollectionID: req.GetCollectionId(),
PartitionID: partID,
ReplicaNumber: req.GetReplicaNumber(),
ReplicaNumber: replicaNumber,
Status: querypb.LoadStatus_Loading,
FieldIndexID: req.GetFieldIndexID(),
FieldIndexID: fieldIndexIDs,
},
CreatedAt: time.Now(),
}
@ -198,22 +125,35 @@ func (job *LoadCollectionJob) Execute() error {
ctx, sp := otel.Tracer(typeutil.QueryCoordRole).Start(job.ctx, "LoadCollection", trace.WithNewRoot())
collection := &meta.Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: req.GetCollectionID(),
ReplicaNumber: req.GetReplicaNumber(),
CollectionID: req.GetCollectionId(),
ReplicaNumber: replicaNumber,
Status: querypb.LoadStatus_Loading,
FieldIndexID: req.GetFieldIndexID(),
FieldIndexID: fieldIndexIDs,
LoadType: querypb.LoadType_LoadCollection,
LoadFields: req.GetLoadFields(),
DbID: job.collInfo.GetDbId(),
UserSpecifiedReplicaMode: job.userSpecifiedReplicaMode,
LoadFields: fieldIDs,
DbID: req.GetDbId(),
UserSpecifiedReplicaMode: req.GetUserSpecifiedReplicaMode(),
},
CreatedAt: time.Now(),
LoadSpan: sp,
Schema: job.collInfo.GetSchema(),
Schema: collInfo.GetSchema(),
}
job.undo.IsNewCollection = true
err = job.meta.CollectionManager.PutCollection(job.ctx, collection, partitions...)
if err != nil {
incomingPartitions := typeutil.NewSet(req.GetPartitionIds()...)
currentPartitions := job.meta.CollectionManager.GetPartitionsByCollection(job.ctx, req.GetCollectionId())
toReleasePartitions := make([]int64, 0)
for _, partition := range currentPartitions {
if !incomingPartitions.Contain(partition.GetPartitionID()) {
toReleasePartitions = append(toReleasePartitions, partition.GetPartitionID())
}
}
if len(toReleasePartitions) > 0 {
job.targetObserver.ReleasePartition(req.GetCollectionId(), toReleasePartitions...)
if err := job.meta.CollectionManager.RemovePartition(job.ctx, req.GetCollectionId(), toReleasePartitions...); err != nil {
return errors.Wrap(err, "failed to remove partitions")
}
}
if err = job.meta.CollectionManager.PutCollection(job.ctx, collection, partitions...); err != nil {
msg := "failed to store collection and partitions"
log.Warn(msg, zap.Error(err))
return errors.Wrap(err, msg)
@ -222,233 +162,11 @@ func (job *LoadCollectionJob) Execute() error {
metrics.QueryCoordNumPartitions.WithLabelValues().Add(float64(len(partitions)))
// 5. update next target; no need to roll back if pulling the target fails, the target observer will pull the target periodically
_, err = job.targetObserver.UpdateNextTarget(req.GetCollectionID())
if err != nil {
msg := "failed to update next target"
log.Warn(msg, zap.Error(err))
}
job.undo.IsTargetUpdated = true
// 6. register load task into collection observer
job.collectionObserver.LoadCollection(ctx, req.GetCollectionID())
return nil
}
func (job *LoadCollectionJob) PostExecute() {
if job.Error() != nil {
job.undo.RollBack()
}
}
type LoadPartitionJob struct {
*BaseJob
req *querypb.LoadPartitionsRequest
undo *UndoList
dist *meta.DistributionManager
meta *meta.Meta
broker meta.Broker
targetMgr meta.TargetManagerInterface
targetObserver *observers.TargetObserver
collectionObserver *observers.CollectionObserver
nodeMgr *session.NodeManager
collInfo *milvuspb.DescribeCollectionResponse
userSpecifiedReplicaMode bool
}
func NewLoadPartitionJob(
ctx context.Context,
req *querypb.LoadPartitionsRequest,
dist *meta.DistributionManager,
meta *meta.Meta,
broker meta.Broker,
targetMgr meta.TargetManagerInterface,
targetObserver *observers.TargetObserver,
collectionObserver *observers.CollectionObserver,
nodeMgr *session.NodeManager,
userSpecifiedReplicaMode bool,
) *LoadPartitionJob {
return &LoadPartitionJob{
BaseJob: NewBaseJob(ctx, req.Base.GetMsgID(), req.GetCollectionID()),
req: req,
undo: NewUndoList(ctx, meta, targetMgr, targetObserver),
dist: dist,
meta: meta,
broker: broker,
targetMgr: targetMgr,
targetObserver: targetObserver,
collectionObserver: collectionObserver,
nodeMgr: nodeMgr,
userSpecifiedReplicaMode: userSpecifiedReplicaMode,
}
}
func (job *LoadPartitionJob) PreExecute() error {
req := job.req
log := log.Ctx(job.ctx).With(zap.Int64("collectionID", req.GetCollectionID()))
if req.GetReplicaNumber() <= 0 {
log.Info("request doesn't indicate the number of replicas, set it to 1",
zap.Int32("replicaNumber", req.GetReplicaNumber()))
req.ReplicaNumber = 1
}
if len(req.GetResourceGroups()) == 0 {
req.ResourceGroups = []string{meta.DefaultResourceGroupName}
}
var err error
job.collInfo, err = job.broker.DescribeCollection(job.ctx, req.GetCollectionID())
if err != nil {
log.Warn("failed to describe collection from RootCoord", zap.Error(err))
if _, err = job.targetObserver.UpdateNextTarget(req.GetCollectionId()); err != nil {
return err
}
collection := job.meta.GetCollection(job.ctx, req.GetCollectionID())
if collection == nil {
return nil
}
if collection.GetReplicaNumber() != req.GetReplicaNumber() {
msg := "collection with different replica number existed, release this collection first before changing its replica number"
log.Warn(msg)
return merr.WrapErrParameterInvalid(collection.GetReplicaNumber(), req.GetReplicaNumber(), "can't change the replica number for loaded partitions")
}
collectionUsedRG := job.meta.ReplicaManager.GetResourceGroupByCollection(job.ctx, collection.GetCollectionID()).Collect()
left, right := lo.Difference(collectionUsedRG, req.GetResourceGroups())
if len(left) > 0 || len(right) > 0 {
msg := fmt.Sprintf("collection with different resource groups %v existed, release this collection first before changing its resource groups",
collectionUsedRG)
log.Warn(msg)
return merr.WrapErrParameterInvalid(collectionUsedRG, req.GetResourceGroups(), "can't change the resource groups for loaded partitions")
}
// 6. register load task into collection observer
job.collectionObserver.LoadPartitions(ctx, req.GetCollectionId(), incomingPartitions.Collect())
return nil
}
func (job *LoadPartitionJob) Execute() error {
req := job.req
log := log.Ctx(job.ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
zap.Int64s("partitionIDs", req.GetPartitionIDs()),
)
meta.GlobalFailedLoadCache.Remove(req.GetCollectionID())
// 1. Fetch target partitions
loadedPartitionIDs := lo.Map(job.meta.CollectionManager.GetPartitionsByCollection(job.ctx, req.GetCollectionID()),
func(partition *meta.Partition, _ int) int64 {
return partition.GetPartitionID()
})
lackPartitionIDs := lo.FilterMap(req.GetPartitionIDs(), func(partID int64, _ int) (int64, bool) {
return partID, !lo.Contains(loadedPartitionIDs, partID)
})
if len(lackPartitionIDs) == 0 {
return nil
}
job.undo.CollectionID = req.GetCollectionID()
job.undo.LackPartitions = lackPartitionIDs
log.Info("find partitions to load", zap.Int64s("partitions", lackPartitionIDs))
var err error
if !job.meta.CollectionManager.Exist(job.ctx, req.GetCollectionID()) {
// Clear stale replicas, https://github.com/milvus-io/milvus/issues/20444
err = job.meta.ReplicaManager.RemoveCollection(job.ctx, req.GetCollectionID())
if err != nil {
msg := "failed to clear stale replicas"
log.Warn(msg, zap.Error(err))
return errors.Wrap(err, msg)
}
}
// 2. create replica if not exist
replicas := job.meta.ReplicaManager.GetByCollection(context.TODO(), req.GetCollectionID())
if len(replicas) == 0 {
_, err = utils.SpawnReplicasWithRG(job.ctx, job.meta, req.GetCollectionID(), req.GetResourceGroups(), req.GetReplicaNumber(),
job.collInfo.GetVirtualChannelNames(), req.GetPriority())
if err != nil {
msg := "failed to spawn replica for collection"
log.Warn(msg, zap.Error(err))
return errors.Wrap(err, msg)
}
job.undo.IsReplicaCreated = true
}
// 4. put collection/partitions meta
partitions := lo.Map(lackPartitionIDs, func(partID int64, _ int) *meta.Partition {
return &meta.Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: req.GetCollectionID(),
PartitionID: partID,
ReplicaNumber: req.GetReplicaNumber(),
Status: querypb.LoadStatus_Loading,
FieldIndexID: req.GetFieldIndexID(),
},
CreatedAt: time.Now(),
}
})
ctx, sp := otel.Tracer(typeutil.QueryCoordRole).Start(job.ctx, "LoadPartition", trace.WithNewRoot())
if !job.meta.CollectionManager.Exist(job.ctx, req.GetCollectionID()) {
job.undo.IsNewCollection = true
collection := &meta.Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: req.GetCollectionID(),
ReplicaNumber: req.GetReplicaNumber(),
Status: querypb.LoadStatus_Loading,
FieldIndexID: req.GetFieldIndexID(),
LoadType: querypb.LoadType_LoadPartition,
LoadFields: req.GetLoadFields(),
DbID: job.collInfo.GetDbId(),
UserSpecifiedReplicaMode: job.userSpecifiedReplicaMode,
},
CreatedAt: time.Now(),
LoadSpan: sp,
Schema: job.collInfo.GetSchema(),
}
err = job.meta.CollectionManager.PutCollection(job.ctx, collection, partitions...)
if err != nil {
msg := "failed to store collection and partitions"
log.Warn(msg, zap.Error(err))
return errors.Wrap(err, msg)
}
} else { // collection exists, put partitions only
coll := job.meta.GetCollection(job.ctx, req.GetCollectionID())
if job.userSpecifiedReplicaMode && !coll.CollectionLoadInfo.UserSpecifiedReplicaMode {
coll.CollectionLoadInfo.UserSpecifiedReplicaMode = job.userSpecifiedReplicaMode
err = job.meta.CollectionManager.PutCollection(job.ctx, coll)
if err != nil {
msg := "failed to store collection"
log.Warn(msg, zap.Error(err))
return errors.Wrap(err, msg)
}
}
err = job.meta.CollectionManager.PutPartition(job.ctx, partitions...)
if err != nil {
msg := "failed to store partitions"
log.Warn(msg, zap.Error(err))
return errors.Wrap(err, msg)
}
}
metrics.QueryCoordNumPartitions.WithLabelValues().Add(float64(len(partitions)))
// 5. update next target; no need to roll back if pulling the target fails, the target observer will pull the target periodically
_, err = job.targetObserver.UpdateNextTarget(req.GetCollectionID())
if err != nil {
msg := "failed to update next target"
log.Warn(msg, zap.Error(err))
}
job.undo.IsTargetUpdated = true
job.collectionObserver.LoadPartitions(ctx, req.GetCollectionID(), lackPartitionIDs)
return nil
}
func (job *LoadPartitionJob) PostExecute() {
if job.Error() != nil {
job.undo.RollBack()
}
}


@ -20,7 +20,6 @@ import (
"context"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -30,14 +29,13 @@ import (
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/util/proxyutil"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/proto/proxypb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)
type ReleaseCollectionJob struct {
*BaseJob
req *querypb.ReleaseCollectionRequest
result message.BroadcastResultDropLoadConfigMessageV2
dist *meta.DistributionManager
meta *meta.Meta
broker meta.Broker
@ -49,7 +47,7 @@ type ReleaseCollectionJob struct {
}
func NewReleaseCollectionJob(ctx context.Context,
req *querypb.ReleaseCollectionRequest,
result message.BroadcastResultDropLoadConfigMessageV2,
dist *meta.DistributionManager,
meta *meta.Meta,
broker meta.Broker,
@ -59,8 +57,8 @@ func NewReleaseCollectionJob(ctx context.Context,
proxyManager proxyutil.ProxyClientManagerInterface,
) *ReleaseCollectionJob {
return &ReleaseCollectionJob{
BaseJob: NewBaseJob(ctx, req.Base.GetMsgID(), req.GetCollectionID()),
req: req,
BaseJob: NewBaseJob(ctx, 0, result.Message.Header().GetCollectionId()),
result: result,
dist: dist,
meta: meta,
broker: broker,
@ -72,146 +70,40 @@ func NewReleaseCollectionJob(ctx context.Context,
}
func (job *ReleaseCollectionJob) Execute() error {
req := job.req
log := log.Ctx(job.ctx).With(zap.Int64("collectionID", req.GetCollectionID()))
collectionID := job.result.Message.Header().GetCollectionId()
log := log.Ctx(job.ctx).With(zap.Int64("collectionID", collectionID))
if !job.meta.CollectionManager.Exist(job.ctx, req.GetCollectionID()) {
if !job.meta.CollectionManager.Exist(job.ctx, collectionID) {
log.Info("release collection end, the collection has not been loaded into QueryNode")
return nil
}
err := job.meta.CollectionManager.RemoveCollection(job.ctx, req.GetCollectionID())
err := job.meta.CollectionManager.RemoveCollection(job.ctx, collectionID)
if err != nil {
msg := "failed to remove collection"
log.Warn(msg, zap.Error(err))
return errors.Wrap(err, msg)
}
err = job.meta.ReplicaManager.RemoveCollection(job.ctx, req.GetCollectionID())
err = job.meta.ReplicaManager.RemoveCollection(job.ctx, collectionID)
if err != nil {
msg := "failed to remove replicas"
log.Warn(msg, zap.Error(err))
}
job.targetObserver.ReleaseCollection(req.GetCollectionID())
job.targetObserver.ReleaseCollection(collectionID)
// try best discard cache
// shall not affect releasing if failed
job.proxyManager.InvalidateCollectionMetaCache(job.ctx,
&proxypb.InvalidateCollMetaCacheRequest{
CollectionID: req.GetCollectionID(),
CollectionID: collectionID,
},
proxyutil.SetMsgType(commonpb.MsgType_ReleaseCollection))
// try best clean shard leader cache
job.proxyManager.InvalidateShardLeaderCache(job.ctx, &proxypb.InvalidateShardLeaderCacheRequest{
CollectionIDs: []int64{req.GetCollectionID()},
CollectionIDs: []int64{collectionID},
})
waitCollectionReleased(job.dist, job.checkerController, req.GetCollectionID())
metrics.QueryCoordReleaseCount.WithLabelValues(metrics.TotalLabel).Inc()
metrics.QueryCoordReleaseCount.WithLabelValues(metrics.SuccessLabel).Inc()
return nil
}
type ReleasePartitionJob struct {
*BaseJob
releasePartitionsOnly bool
req *querypb.ReleasePartitionsRequest
dist *meta.DistributionManager
meta *meta.Meta
broker meta.Broker
cluster session.Cluster
targetMgr meta.TargetManagerInterface
targetObserver *observers.TargetObserver
checkerController *checkers.CheckerController
proxyManager proxyutil.ProxyClientManagerInterface
}
func NewReleasePartitionJob(ctx context.Context,
req *querypb.ReleasePartitionsRequest,
dist *meta.DistributionManager,
meta *meta.Meta,
broker meta.Broker,
targetMgr meta.TargetManagerInterface,
targetObserver *observers.TargetObserver,
checkerController *checkers.CheckerController,
proxyManager proxyutil.ProxyClientManagerInterface,
) *ReleasePartitionJob {
return &ReleasePartitionJob{
BaseJob: NewBaseJob(ctx, req.Base.GetMsgID(), req.GetCollectionID()),
req: req,
dist: dist,
meta: meta,
broker: broker,
targetMgr: targetMgr,
targetObserver: targetObserver,
checkerController: checkerController,
proxyManager: proxyManager,
}
}
func (job *ReleasePartitionJob) Execute() error {
req := job.req
log := log.Ctx(job.ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
zap.Int64s("partitionIDs", req.GetPartitionIDs()),
)
if !job.meta.CollectionManager.Exist(job.ctx, req.GetCollectionID()) {
log.Info("release collection end, the collection has not been loaded into QueryNode")
return nil
}
loadedPartitions := job.meta.CollectionManager.GetPartitionsByCollection(job.ctx, req.GetCollectionID())
toRelease := lo.FilterMap(loadedPartitions, func(partition *meta.Partition, _ int) (int64, bool) {
return partition.GetPartitionID(), lo.Contains(req.GetPartitionIDs(), partition.GetPartitionID())
})
if len(toRelease) == 0 {
log.Warn("releasing partition(s) not loaded")
return nil
}
// If all partitions are released, clear all
if len(toRelease) == len(loadedPartitions) {
log.Info("release partitions covers all partitions, will remove the whole collection")
err := job.meta.CollectionManager.RemoveCollection(job.ctx, req.GetCollectionID())
if err != nil {
msg := "failed to release partitions from store"
log.Warn(msg, zap.Error(err))
return errors.Wrap(err, msg)
}
err = job.meta.ReplicaManager.RemoveCollection(job.ctx, req.GetCollectionID())
if err != nil {
log.Warn("failed to remove replicas", zap.Error(err))
}
job.targetObserver.ReleaseCollection(req.GetCollectionID())
// try best discard cache
// shall not affect releasing if failed
job.proxyManager.InvalidateCollectionMetaCache(job.ctx,
&proxypb.InvalidateCollMetaCacheRequest{
CollectionID: req.GetCollectionID(),
},
proxyutil.SetMsgType(commonpb.MsgType_ReleaseCollection))
// try best clean shard leader cache
job.proxyManager.InvalidateShardLeaderCache(job.ctx, &proxypb.InvalidateShardLeaderCacheRequest{
CollectionIDs: []int64{req.GetCollectionID()},
})
waitCollectionReleased(job.dist, job.checkerController, req.GetCollectionID())
} else {
err := job.meta.CollectionManager.RemovePartition(job.ctx, req.GetCollectionID(), toRelease...)
if err != nil {
msg := "failed to release partitions from store"
log.Warn(msg, zap.Error(err))
return errors.Wrap(err, msg)
}
job.targetObserver.ReleasePartition(req.GetCollectionID(), toRelease...)
// wait for the current target to be updated, so following queries will act as expected
waitCurrentTargetUpdated(job.ctx, job.targetObserver, job.req.GetCollectionID())
waitCollectionReleased(job.dist, job.checkerController, req.GetCollectionID(), toRelease...)
}
return nil
}


@ -96,5 +96,5 @@ func (job *SyncNewCreatedPartitionJob) Execute() error {
return errors.Wrap(err, msg)
}
return waitCurrentTargetUpdated(job.ctx, job.targetObserver, job.req.GetCollectionID())
return WaitCurrentTargetUpdated(job.ctx, job.targetObserver, job.req.GetCollectionID())
}

File diff suppressed because it is too large


@ -0,0 +1,248 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package job
import (
"context"
"sort"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
var ErrIgnoredAlterLoadConfig = errors.New("ignored alter load config")
type AlterLoadConfigRequest struct {
Meta *meta.Meta
CollectionInfo *milvuspb.DescribeCollectionResponse
Expected ExpectedLoadConfig
Current CurrentLoadConfig
}
// CheckIfLoadPartitionsExecutable checks whether the load-partitions request is executable.
func (req *AlterLoadConfigRequest) CheckIfLoadPartitionsExecutable() error {
if req.Current.Collection == nil {
return nil
}
expectedReplicaNumber := 0
for _, num := range req.Expected.ExpectedReplicaNumber {
expectedReplicaNumber += num
}
if len(req.Current.Replicas) != expectedReplicaNumber {
return merr.WrapErrParameterInvalid(len(req.Current.Replicas), expectedReplicaNumber, "can't change the replica number for loaded partitions")
}
return nil
}
type ExpectedLoadConfig struct {
ExpectedPartitionIDs []int64
ExpectedReplicaNumber map[string]int // map resource group name to replica number in resource group
ExpectedFieldIndexID map[int64]int64
ExpectedLoadFields []int64
ExpectedPriority commonpb.LoadPriority
ExpectedUserSpecifiedReplicaMode bool
}
type CurrentLoadConfig struct {
Collection *meta.Collection
Partitions map[int64]*meta.Partition
Replicas map[int64]*meta.Replica
}
func (c *CurrentLoadConfig) GetLoadPriority() commonpb.LoadPriority {
for _, replica := range c.Replicas {
return replica.LoadPriority()
}
return commonpb.LoadPriority_HIGH
}
func (c *CurrentLoadConfig) GetFieldIndexID() map[int64]int64 {
return c.Collection.FieldIndexID
}
func (c *CurrentLoadConfig) GetLoadFields() []int64 {
return c.Collection.LoadFields
}
func (c *CurrentLoadConfig) GetUserSpecifiedReplicaMode() bool {
return c.Collection.UserSpecifiedReplicaMode
}
func (c *CurrentLoadConfig) GetReplicaNumber() map[string]int {
replicaNumber := make(map[string]int)
for _, replica := range c.Replicas {
replicaNumber[replica.GetResourceGroup()]++
}
return replicaNumber
}
func (c *CurrentLoadConfig) GetPartitionIDs() []int64 {
partitionIDs := make([]int64, 0, len(c.Partitions))
for _, partition := range c.Partitions {
partitionIDs = append(partitionIDs, partition.GetPartitionID())
}
return partitionIDs
}
// IntoLoadConfigMessageHeader converts the current load config into a load config message header.
func (c *CurrentLoadConfig) IntoLoadConfigMessageHeader() *messagespb.AlterLoadConfigMessageHeader {
if c.Collection == nil {
return nil
}
partitionIDs := make([]int64, 0, len(c.Partitions))
partitionIDs = append(partitionIDs, c.GetPartitionIDs()...)
sort.Slice(partitionIDs, func(i, j int) bool {
return partitionIDs[i] < partitionIDs[j]
})
loadFields := generateLoadFields(c.GetLoadFields(), c.GetFieldIndexID())
replicas := make([]*messagespb.LoadReplicaConfig, 0, len(c.Replicas))
for _, replica := range c.Replicas {
replicas = append(replicas, &messagespb.LoadReplicaConfig{
ReplicaId: replica.GetID(),
ResourceGroupName: replica.GetResourceGroup(),
Priority: replica.LoadPriority(),
})
}
sort.Slice(replicas, func(i, j int) bool {
return replicas[i].GetReplicaId() < replicas[j].GetReplicaId()
})
return &messagespb.AlterLoadConfigMessageHeader{
DbId: c.Collection.DbID,
CollectionId: c.Collection.CollectionID,
PartitionIds: partitionIDs,
LoadFields: loadFields,
Replicas: replicas,
UserSpecifiedReplicaMode: c.GetUserSpecifiedReplicaMode(),
}
}
// GenerateAlterLoadConfigMessage generates the alter load config message for the collection.
func GenerateAlterLoadConfigMessage(ctx context.Context, req *AlterLoadConfigRequest) (message.BroadcastMutableMessage, error) {
loadFields := generateLoadFields(req.Expected.ExpectedLoadFields, req.Expected.ExpectedFieldIndexID)
loadReplicaConfigs, err := req.generateReplicas(ctx)
if err != nil {
return nil, err
}
partitionIDs := make([]int64, 0, len(req.Expected.ExpectedPartitionIDs))
partitionIDs = append(partitionIDs, req.Expected.ExpectedPartitionIDs...)
sort.Slice(partitionIDs, func(i, j int) bool {
return partitionIDs[i] < partitionIDs[j]
})
header := &messagespb.AlterLoadConfigMessageHeader{
DbId: req.CollectionInfo.DbId,
CollectionId: req.CollectionInfo.CollectionID,
PartitionIds: partitionIDs,
LoadFields: loadFields,
Replicas: loadReplicaConfigs,
UserSpecifiedReplicaMode: req.Expected.ExpectedUserSpecifiedReplicaMode,
}
// check if the load configuration is changed
if previousHeader := req.Current.IntoLoadConfigMessageHeader(); proto.Equal(previousHeader, header) {
return nil, ErrIgnoredAlterLoadConfig
}
return message.NewAlterLoadConfigMessageBuilderV2().
WithHeader(header).
WithBody(&messagespb.AlterLoadConfigMessageBody{}).
WithBroadcast([]string{streaming.WAL().ControlChannel()}).
MustBuildBroadcast(), nil
}
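
For orientation only (not part of this change), here is a minimal sketch of how a caller inside this job package might drive GenerateAlterLoadConfigMessage; the function name, collection info, partition IDs, field IDs, and replica counts below are hypothetical placeholders.

func exampleAlterLoadConfigBroadcast(ctx context.Context, m *meta.Meta, collInfo *milvuspb.DescribeCollectionResponse) (message.BroadcastMutableMessage, error) {
	req := &AlterLoadConfigRequest{
		Meta:           m,
		CollectionInfo: collInfo,
		Expected: ExpectedLoadConfig{
			ExpectedPartitionIDs:  []int64{100, 101},                                // hypothetical partitions
			ExpectedReplicaNumber: map[string]int{meta.DefaultResourceGroupName: 2}, // two replicas in the default RG
			ExpectedLoadFields:    []int64{100, 101},
			ExpectedPriority:      commonpb.LoadPriority_HIGH,
		},
		Current: CurrentLoadConfig{}, // empty: the collection is not loaded yet
	}
	msg, err := GenerateAlterLoadConfigMessage(ctx, req)
	if errors.Is(err, ErrIgnoredAlterLoadConfig) {
		return nil, nil // expected config equals the current one, nothing to broadcast
	}
	return msg, err
}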
// generateLoadFields generates the load fields for the collection.
func generateLoadFields(loadedFields []int64, fieldIndexID map[int64]int64) []*messagespb.LoadFieldConfig {
loadFields := lo.Map(loadedFields, func(fieldID int64, _ int) *messagespb.LoadFieldConfig {
if indexID, ok := fieldIndexID[fieldID]; ok {
return &messagespb.LoadFieldConfig{
FieldId: fieldID,
IndexId: indexID,
}
}
return &messagespb.LoadFieldConfig{
FieldId: fieldID,
IndexId: 0,
}
})
sort.Slice(loadFields, func(i, j int) bool {
return loadFields[i].GetFieldId() < loadFields[j].GetFieldId()
})
return loadFields
}
// generateReplicas generates the replicas for the collection.
func (req *AlterLoadConfigRequest) generateReplicas(ctx context.Context) ([]*messagespb.LoadReplicaConfig, error) {
// fill up existsReplicaNum, finding the redundant replicas and the replicas that should be kept
existsReplicaNum := make(map[string]int)
keptReplicas := make(map[int64]struct{}) // replica that should be kept
redundantReplicas := make([]int64, 0) // replica that should be removed
loadReplicaConfigs := make([]*messagespb.LoadReplicaConfig, 0)
for _, replica := range req.Current.Replicas {
if existsReplicaNum[replica.GetResourceGroup()] >= req.Expected.ExpectedReplicaNumber[replica.GetResourceGroup()] {
redundantReplicas = append(redundantReplicas, replica.GetID())
continue
}
keptReplicas[replica.GetID()] = struct{}{}
loadReplicaConfigs = append(loadReplicaConfigs, &messagespb.LoadReplicaConfig{
ReplicaId: replica.GetID(),
ResourceGroupName: replica.GetResourceGroup(),
Priority: replica.LoadPriority(),
})
existsReplicaNum[replica.GetResourceGroup()]++
}
// check whether new incoming replicas should be generated.
for rg, num := range req.Expected.ExpectedReplicaNumber {
for i := existsReplicaNum[rg]; i < num; i++ {
if len(redundantReplicas) > 0 {
// reuse a replica from the redundant replicas.
// this makes it a transfer operation from one resource group to another.
replicaID := redundantReplicas[0]
redundantReplicas = redundantReplicas[1:]
loadReplicaConfigs = append(loadReplicaConfigs, &messagespb.LoadReplicaConfig{
ReplicaId: replicaID,
ResourceGroupName: rg,
Priority: req.Expected.ExpectedPriority,
})
} else {
// allocate a new replica.
newID, err := req.Meta.ReplicaManager.AllocateReplicaID(ctx)
if err != nil {
return nil, err
}
loadReplicaConfigs = append(loadReplicaConfigs, &messagespb.LoadReplicaConfig{
ReplicaId: newID,
ResourceGroupName: rg,
Priority: req.Expected.ExpectedPriority,
})
}
}
}
sort.Slice(loadReplicaConfigs, func(i, j int) bool {
return loadReplicaConfigs[i].GetReplicaId() < loadReplicaConfigs[j].GetReplicaId()
})
return loadReplicaConfigs, nil
}
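
A worked illustration of the reuse/transfer path in generateReplicas (hypothetical replica IDs and resource groups; it also assumes querypb is imported): two replicas currently live in RG1, while the expected config asks for one replica in RG1 and one in RG2, so the redundant replica is reassigned rather than dropped and re-allocated.

req := &AlterLoadConfigRequest{
	Expected: ExpectedLoadConfig{
		ExpectedReplicaNumber: map[string]int{"RG1": 1, "RG2": 1},
		ExpectedPriority:      commonpb.LoadPriority_HIGH,
	},
	Current: CurrentLoadConfig{
		Replicas: map[int64]*meta.Replica{
			1: meta.NewReplicaWithPriority(&querypb.Replica{ID: 1, ResourceGroup: "RG1"}, commonpb.LoadPriority_HIGH),
			2: meta.NewReplicaWithPriority(&querypb.Replica{ID: 2, ResourceGroup: "RG1"}, commonpb.LoadPriority_HIGH),
		},
	},
}
configs, _ := req.generateReplicas(context.Background())
// configs keeps one replica in RG1 and reassigns the other to RG2 with the
// expected priority, sorted by replica ID; no new replica ID is allocated.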


@ -30,10 +30,10 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// waitCollectionReleased blocks until
// WaitCollectionReleased blocks until
// all channels and segments of given collection(partitions) are released,
// empty partition list means wait for collection released
func waitCollectionReleased(dist *meta.DistributionManager, checkerController *checkers.CheckerController, collection int64, partitions ...int64) {
func WaitCollectionReleased(dist *meta.DistributionManager, checkerController *checkers.CheckerController, collection int64, partitions ...int64) {
partitionSet := typeutil.NewUniqueSet(partitions...)
for {
var (
@ -64,7 +64,7 @@ func waitCollectionReleased(dist *meta.DistributionManager, checkerController *c
}
}
func waitCurrentTargetUpdated(ctx context.Context, targetObserver *observers.TargetObserver, collection int64) error {
func WaitCurrentTargetUpdated(ctx context.Context, targetObserver *observers.TargetObserver, collection int64) error {
// manually trigger an update of the next target
ready, err := targetObserver.UpdateNextTarget(collection)
if err != nil {


@ -114,7 +114,7 @@ func NewReplicaWithPriority(replica *querypb.Replica, priority commonpb.LoadPrio
}
func (replica *Replica) LoadPriority() commonpb.LoadPriority {
return replica.loadPriority
return replica.loadPriority // TODO: the load priority isn't persisted into the replica recovery info.
}
// GetID returns the id of the replica.


@ -30,6 +30,7 @@ import (
"github.com/milvus-io/milvus/internal/metastore"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
@ -162,6 +163,70 @@ func (m *ReplicaManager) Get(ctx context.Context, id typeutil.UniqueID) *Replica
return m.replicas[id]
}
type SpawnWithReplicaConfigParams struct {
CollectionID int64
Channels []string
Configs []*messagespb.LoadReplicaConfig
}
// SpawnWithReplicaConfig spawns replicas with replica config.
func (m *ReplicaManager) SpawnWithReplicaConfig(ctx context.Context, params SpawnWithReplicaConfigParams) ([]*Replica, error) {
m.rwmutex.Lock()
defer m.rwmutex.Unlock()
balancePolicy := paramtable.Get().QueryCoordCfg.Balancer.GetValue()
enableChannelExclusiveMode := balancePolicy == ChannelLevelScoreBalancerName
replicas := make([]*Replica, 0)
for _, config := range params.Configs {
replica := NewReplicaWithPriority(&querypb.Replica{
ID: config.GetReplicaId(),
CollectionID: params.CollectionID,
ResourceGroup: config.ResourceGroupName,
}, config.GetPriority())
if enableChannelExclusiveMode {
mutableReplica := replica.CopyForWrite()
mutableReplica.TryEnableChannelExclusiveMode(params.Channels...)
replica = mutableReplica.IntoReplica()
}
replicas = append(replicas, replica)
}
if err := m.put(ctx, replicas...); err != nil {
return nil, errors.Wrap(err, "failed to put replicas")
}
if err := m.removeRedundantReplicas(ctx, params); err != nil {
return nil, errors.Wrap(err, "failed to remove redundant replicas")
}
return replicas, nil
}
// removeRedundantReplicas removes redundant replicas that are not in the new replica config.
func (m *ReplicaManager) removeRedundantReplicas(ctx context.Context, params SpawnWithReplicaConfigParams) error {
existedReplicas, ok := m.coll2Replicas[params.CollectionID]
if !ok {
return nil
}
toRemoveReplicas := make([]int64, 0)
for _, replica := range existedReplicas.replicas {
found := false
replicaID := replica.GetID()
for _, channel := range params.Configs {
if channel.GetReplicaId() == replicaID {
found = true
break
}
}
if !found {
toRemoveReplicas = append(toRemoveReplicas, replicaID)
}
}
return m.removeReplicas(ctx, params.CollectionID, toRemoveReplicas...)
}
// AllocateReplicaID allocates a replica ID.
func (m *ReplicaManager) AllocateReplicaID(ctx context.Context) (int64, error) {
return m.idAllocator()
}
// Spawn spawns N replicas at resource group for given collection in ReplicaManager.
func (m *ReplicaManager) Spawn(ctx context.Context, collection int64, replicaNumInRG map[string]int,
channels []string, loadPriority commonpb.LoadPriority,


@ -35,6 +35,7 @@ import (
"github.com/milvus-io/milvus/internal/metastore/mocks"
. "github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/pkg/v2/kv"
"github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/util/etcd"
"github.com/milvus-io/milvus/pkg/v2/util/metricsinfo"
@ -116,6 +117,42 @@ func (suite *ReplicaManagerSuite) TearDownTest() {
suite.kv.Close()
}
func (suite *ReplicaManagerSuite) TestSpawnWithReplicaConfig() {
mgr := suite.mgr
ctx := suite.ctx
replicas, err := mgr.SpawnWithReplicaConfig(ctx, SpawnWithReplicaConfigParams{
CollectionID: 100,
Channels: []string{"channel1", "channel2"},
Configs: []*messagespb.LoadReplicaConfig{
{ReplicaId: 1000, ResourceGroupName: "RG1", Priority: commonpb.LoadPriority_HIGH},
},
})
suite.NoError(err)
suite.Len(replicas, 1)
replicas, err = mgr.SpawnWithReplicaConfig(ctx, SpawnWithReplicaConfigParams{
CollectionID: 100,
Channels: []string{"channel1", "channel2"},
Configs: []*messagespb.LoadReplicaConfig{
{ReplicaId: 1000, ResourceGroupName: "RG1", Priority: commonpb.LoadPriority_HIGH},
{ReplicaId: 1001, ResourceGroupName: "RG1", Priority: commonpb.LoadPriority_HIGH},
},
})
suite.NoError(err)
suite.Len(replicas, 2)
replicas, err = mgr.SpawnWithReplicaConfig(ctx, SpawnWithReplicaConfigParams{
CollectionID: 100,
Channels: []string{"channel1", "channel2"},
Configs: []*messagespb.LoadReplicaConfig{
{ReplicaId: 1000, ResourceGroupName: "RG1", Priority: commonpb.LoadPriority_HIGH},
},
})
suite.NoError(err)
suite.Len(replicas, 1)
}
func (suite *ReplicaManagerSuite) TestSpawn() {
mgr := suite.mgr
ctx := suite.ctx


@ -21,6 +21,7 @@ import (
"fmt"
"math/rand"
"os"
"strconv"
"sync"
"testing"
"time"
@ -45,14 +46,15 @@ import (
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/task"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/etcd"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
@ -86,8 +88,6 @@ type ServerSuite struct {
ctx context.Context
}
var testMeta string
func (suite *ServerSuite) SetupSuite() {
paramtable.Init()
params.GenerateEtcdConfig()
@ -124,8 +124,9 @@ func (suite *ServerSuite) SetupSuite() {
}
func (suite *ServerSuite) SetupTest() {
initStreamingSystem()
var err error
paramtable.Get().Save(paramtable.Get().MetaStoreCfg.MetaStoreType.Key, testMeta)
suite.tikvCli = tikv.SetupLocalTxn()
suite.server, err = suite.newQueryCoord()
@ -627,11 +628,17 @@ func (suite *ServerSuite) hackServer() {
suite.server.proxyClientManager,
)
suite.broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{Schema: &schemapb.CollectionSchema{}}, nil).Maybe()
suite.broker.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&rootcoordpb.DescribeDatabaseResponse{}, nil).Maybe()
suite.broker.EXPECT().ListIndexes(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
suite.broker.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&rootcoordpb.DescribeDatabaseResponse{}, nil).Maybe()
for _, collection := range suite.collections {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe()
suite.broker.EXPECT().DescribeCollection(mock.Anything, collection).Return(&milvuspb.DescribeCollectionResponse{
DbName: util.DefaultDBName,
DbId: 1,
CollectionID: collection,
CollectionName: "collection_" + strconv.FormatInt(collection, 10),
Schema: &schemapb.CollectionSchema{},
}, nil).Maybe()
suite.expectGetRecoverInfo(collection)
}
log.Debug("server hacked")
@ -670,18 +677,7 @@ func (suite *ServerSuite) newQueryCoord() (*Server, error) {
if err != nil {
return nil, err
}
etcdCli, err := etcd.GetEtcdClient(
Params.EtcdCfg.UseEmbedEtcd.GetAsBool(),
Params.EtcdCfg.EtcdUseSSL.GetAsBool(),
Params.EtcdCfg.Endpoints.GetAsStrings(),
Params.EtcdCfg.EtcdTLSCert.GetValue(),
Params.EtcdCfg.EtcdTLSKey.GetValue(),
Params.EtcdCfg.EtcdTLSCACert.GetValue(),
Params.EtcdCfg.EtcdTLSMinVersion.GetValue())
if err != nil {
return nil, err
}
etcdCli, _ := kvfactory.GetEtcdAndPath()
server.SetEtcdClient(etcdCli)
server.SetTiKVClient(suite.tikvCli)
@ -931,9 +927,5 @@ func createTestSession(nodeID int64, address string, stopping bool) *sessionutil
}
func TestServer(t *testing.T) {
parameters := []string{"tikv", "etcd"}
for _, v := range parameters {
testMeta = v
suite.Run(t, new(ServerSuite))
}
suite.Run(t, new(ServerSuite))
}


@ -194,284 +194,173 @@ func (s *Server) ShowLoadPartitions(ctx context.Context, req *querypb.ShowPartit
}
func (s *Server) LoadCollection(ctx context.Context, req *querypb.LoadCollectionRequest) (*commonpb.Status, error) {
log := log.Ctx(ctx).With(
logger := log.Ctx(ctx).With(
zap.Int64("dbID", req.GetDbID()),
zap.Int64("collectionID", req.GetCollectionID()),
zap.Int32("replicaNumber", req.GetReplicaNumber()),
zap.Strings("resourceGroups", req.GetResourceGroups()),
zap.Bool("refreshMode", req.GetRefresh()),
)
log.Info("load collection request received",
logger.Info("load collection request received",
zap.Any("schema", req.Schema),
zap.Int64s("fieldIndexes", lo.Values(req.GetFieldIndexID())),
)
metrics.QueryCoordLoadCount.WithLabelValues(metrics.TotalLabel).Inc()
if err := merr.CheckHealthy(s.State()); err != nil {
msg := "failed to load collection"
log.Warn(msg, zap.Error(err))
logger.Warn("failed to load collection", zap.Error(err))
metrics.QueryCoordLoadCount.WithLabelValues(metrics.FailLabel).Inc()
return merr.Status(errors.Wrap(err, msg)), nil
return merr.Status(err), nil
}
// If refresh mode is ON.
if req.GetRefresh() {
err := s.refreshCollection(ctx, req.GetCollectionID())
if err != nil {
log.Warn("failed to refresh collection", zap.Error(err))
logger.Warn("failed to refresh collection", zap.Error(err))
metrics.QueryCoordLoadCount.WithLabelValues(metrics.FailLabel).Inc()
return merr.Status(err), nil
}
logger.Info("refresh collection done")
metrics.QueryCoordLoadCount.WithLabelValues(metrics.SuccessLabel).Inc()
return merr.Success(), nil
}
if err := s.broadcastAlterLoadConfigCollectionV2ForLoadCollection(ctx, req); err != nil {
if errors.Is(err, job.ErrIgnoredAlterLoadConfig) {
logger.Info("load collection ignored, collection is already loaded")
metrics.QueryCoordLoadCount.WithLabelValues(metrics.SuccessLabel).Inc()
return merr.Success(), nil
}
logger.Warn("failed to load collection", zap.Error(err))
metrics.QueryCoordLoadCount.WithLabelValues(metrics.FailLabel).Inc()
return merr.Status(err), nil
}
// if the user specified the replica number in the load request, load config changes won't be applied to the collection automatically
userSpecifiedReplicaMode := req.GetReplicaNumber() > 0
// to be compatible with the old SDK, which sets replica=1 if replica is not specified,
// fall back to the configured load info only when both replica number and resource groups are unset in the request
if req.GetReplicaNumber() <= 0 && len(req.GetResourceGroups()) == 0 {
// when replica number or resource groups is not set, use pre-defined load config
rgs, replicas, err := s.broker.GetCollectionLoadInfo(ctx, req.GetCollectionID())
if err != nil {
log.Warn("failed to get pre-defined load info", zap.Error(err))
} else {
if req.GetReplicaNumber() <= 0 && replicas > 0 {
req.ReplicaNumber = int32(replicas)
}
if len(req.GetResourceGroups()) == 0 && len(rgs) > 0 {
req.ResourceGroups = rgs
}
}
}
if req.GetReplicaNumber() <= 0 {
log.Info("request doesn't indicate the number of replicas, set it to 1")
req.ReplicaNumber = 1
}
if len(req.GetResourceGroups()) == 0 {
log.Info(fmt.Sprintf("request doesn't indicate the resource groups, set it to %s", meta.DefaultResourceGroupName))
req.ResourceGroups = []string{meta.DefaultResourceGroupName}
}
var loadJob job.Job
collection := s.meta.GetCollection(ctx, req.GetCollectionID())
if collection != nil {
// if collection is loaded, check if collection is loaded with the same replica number and resource groups
// if replica number or resource group changes switch to update load config
collectionUsedRG := s.meta.ReplicaManager.GetResourceGroupByCollection(ctx, collection.GetCollectionID()).Collect()
left, right := lo.Difference(collectionUsedRG, req.GetResourceGroups())
rgChanged := len(left) > 0 || len(right) > 0
replicaChanged := collection.GetReplicaNumber() != req.GetReplicaNumber()
if replicaChanged || rgChanged {
log.Warn("collection is loaded with different replica number or resource group, switch to update load config",
zap.Int32("oldReplicaNumber", collection.GetReplicaNumber()),
zap.Strings("oldResourceGroups", collectionUsedRG))
updateReq := &querypb.UpdateLoadConfigRequest{
CollectionIDs: []int64{req.GetCollectionID()},
ReplicaNumber: req.GetReplicaNumber(),
ResourceGroups: req.GetResourceGroups(),
}
loadJob = job.NewUpdateLoadConfigJob(
ctx,
updateReq,
s.meta,
s.targetMgr,
s.targetObserver,
s.collectionObserver,
userSpecifiedReplicaMode,
)
}
}
if loadJob == nil {
loadJob = job.NewLoadCollectionJob(ctx,
req,
s.dist,
s.meta,
s.broker,
s.targetMgr,
s.targetObserver,
s.collectionObserver,
s.nodeMgr,
userSpecifiedReplicaMode,
)
}
s.jobScheduler.Add(loadJob)
err := loadJob.Wait()
if err != nil {
msg := "failed to load collection"
log.Warn(msg, zap.Error(err))
metrics.QueryCoordLoadCount.WithLabelValues(metrics.FailLabel).Inc()
return merr.Status(errors.Wrap(err, msg)), nil
}
logger.Info("load collection done")
metrics.QueryCoordLoadCount.WithLabelValues(metrics.SuccessLabel).Inc()
return merr.Success(), nil
}
func (s *Server) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
)
logger := log.Ctx(ctx).With(zap.Int64("collectionID", req.GetCollectionID()))
log.Info("release collection request received")
logger.Info("release collection request received")
metrics.QueryCoordReleaseCount.WithLabelValues(metrics.TotalLabel).Inc()
tr := timerecord.NewTimeRecorder("release-collection")
if err := merr.CheckHealthy(s.State()); err != nil {
msg := "failed to release collection"
log.Warn(msg, zap.Error(err))
logger.Warn("failed to release collection", zap.Error(err))
metrics.QueryCoordReleaseCount.WithLabelValues(metrics.FailLabel).Inc()
return merr.Status(errors.Wrap(err, msg)), nil
return merr.Status(err), nil
}
releaseJob := job.NewReleaseCollectionJob(ctx,
req,
s.dist,
s.meta,
s.broker,
s.targetMgr,
s.targetObserver,
s.checkerController,
s.proxyClientManager,
)
s.jobScheduler.Add(releaseJob)
err := releaseJob.Wait()
if err != nil {
msg := "failed to release collection"
log.Warn(msg, zap.Error(err))
if err := s.broadcastDropLoadConfigCollectionV2ForReleaseCollection(ctx, req); err != nil {
if errors.Is(err, errReleaseCollectionNotLoaded) {
logger.Info("release collection ignored, collection is not loaded")
metrics.QueryCoordReleaseCount.WithLabelValues(metrics.SuccessLabel).Inc()
return merr.Success(), nil
}
logger.Warn("failed to release collection", zap.Error(err))
metrics.QueryCoordReleaseCount.WithLabelValues(metrics.FailLabel).Inc()
return merr.Status(errors.Wrap(err, msg)), nil
return merr.Status(err), nil
}
log.Info("collection released")
job.WaitCollectionReleased(s.dist, s.checkerController, req.GetCollectionID())
logger.Info("release collection done")
metrics.QueryCoordReleaseCount.WithLabelValues(metrics.SuccessLabel).Inc()
metrics.QueryCoordReleaseLatency.WithLabelValues().Observe(float64(tr.ElapseSpan().Milliseconds()))
meta.GlobalFailedLoadCache.Remove(req.GetCollectionID())
return merr.Success(), nil
}
func (s *Server) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
log := log.Ctx(ctx).With(
logger := log.Ctx(ctx).With(
zap.Int64("dbID", req.GetDbID()),
zap.Int64("collectionID", req.GetCollectionID()),
zap.Int32("replicaNumber", req.GetReplicaNumber()),
zap.Int64s("partitions", req.GetPartitionIDs()),
zap.Strings("resourceGroups", req.GetResourceGroups()),
zap.Bool("refreshMode", req.GetRefresh()),
)
log.Info("received load partitions request",
zap.Any("schema", req.Schema),
zap.Int64s("partitions", req.GetPartitionIDs()))
logger.Info("received load partitions request",
zap.Any("schema", req.Schema))
metrics.QueryCoordLoadCount.WithLabelValues(metrics.TotalLabel).Inc()
if err := merr.CheckHealthy(s.State()); err != nil {
msg := "failed to load partitions"
log.Warn(msg, zap.Error(err))
logger.Warn("failed to load partitions", zap.Error(err))
metrics.QueryCoordLoadCount.WithLabelValues(metrics.FailLabel).Inc()
return merr.Status(errors.Wrap(err, msg)), nil
return merr.Status(err), nil
}
// If refresh mode is ON.
if req.GetRefresh() {
err := s.refreshCollection(ctx, req.GetCollectionID())
if err != nil {
log.Warn("failed to refresh partitions", zap.Error(err))
logger.Warn("failed to refresh partitions", zap.Error(err))
metrics.QueryCoordLoadCount.WithLabelValues(metrics.FailLabel).Inc()
return merr.Status(err), nil
}
logger.Info("refresh partitions done")
metrics.QueryCoordLoadCount.WithLabelValues(metrics.SuccessLabel).Inc()
return merr.Success(), nil
}
if err := s.broadcastAlterLoadConfigCollectionV2ForLoadPartitions(ctx, req); err != nil {
if errors.Is(err, job.ErrIgnoredAlterLoadConfig) {
logger.Info("load partitions ignored, partitions are already loaded")
metrics.QueryCoordLoadCount.WithLabelValues(metrics.SuccessLabel).Inc()
return merr.Success(), nil
}
logger.Warn("failed to load partitions", zap.Error(err))
metrics.QueryCoordLoadCount.WithLabelValues(metrics.FailLabel).Inc()
return merr.Status(err), nil
}
// if the user specified the replica number in the load request, load config changes won't be applied to the collection automatically
userSpecifiedReplicaMode := req.GetReplicaNumber() > 0
// to be compatible with the old SDK, which sets replica=1 if replica is not specified,
// fall back to the configured load info only when both replica number and resource groups are unset in the request
if req.GetReplicaNumber() <= 0 && len(req.GetResourceGroups()) == 0 {
// when replica number or resource groups is not set, use database level config
rgs, replicas, err := s.broker.GetCollectionLoadInfo(ctx, req.GetCollectionID())
if err != nil {
log.Warn("failed to get data base level load info", zap.Error(err))
}
if req.GetReplicaNumber() <= 0 {
log.Info("load collection use database level replica number", zap.Int64("databaseLevelReplicaNum", replicas))
req.ReplicaNumber = int32(replicas)
}
if len(req.GetResourceGroups()) == 0 {
log.Info("load collection use database level resource groups", zap.Strings("databaseLevelResourceGroups", rgs))
req.ResourceGroups = rgs
}
}
loadJob := job.NewLoadPartitionJob(ctx,
req,
s.dist,
s.meta,
s.broker,
s.targetMgr,
s.targetObserver,
s.collectionObserver,
s.nodeMgr,
userSpecifiedReplicaMode,
)
s.jobScheduler.Add(loadJob)
err := loadJob.Wait()
if err != nil {
msg := "failed to load partitions"
log.Warn(msg, zap.Error(err))
metrics.QueryCoordLoadCount.WithLabelValues(metrics.FailLabel).Inc()
return merr.Status(errors.Wrap(err, msg)), nil
}
logger.Info("load partitions done")
metrics.QueryCoordLoadCount.WithLabelValues(metrics.SuccessLabel).Inc()
return merr.Success(), nil
}
func (s *Server) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
log := log.Ctx(ctx).With(
logger := log.Ctx(ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
zap.Int64s("partitionIDs", req.GetPartitionIDs()),
)
log.Info("release partitions", zap.Int64s("partitions", req.GetPartitionIDs()))
logger.Info("release partitions")
metrics.QueryCoordReleaseCount.WithLabelValues(metrics.TotalLabel).Inc()
if err := merr.CheckHealthy(s.State()); err != nil {
msg := "failed to release partitions"
log.Warn(msg, zap.Error(err))
metrics.QueryCoordReleaseCount.WithLabelValues(metrics.FailLabel).Inc()
return merr.Status(errors.Wrap(err, msg)), nil
}
if len(req.GetPartitionIDs()) == 0 {
err := merr.WrapErrParameterInvalid("any partition", "empty partition list")
log.Warn("no partition to release", zap.Error(err))
logger.Warn("failed to release partitions", zap.Error(err))
metrics.QueryCoordReleaseCount.WithLabelValues(metrics.FailLabel).Inc()
return merr.Status(err), nil
}
tr := timerecord.NewTimeRecorder("release-partitions")
releaseJob := job.NewReleasePartitionJob(ctx,
req,
s.dist,
s.meta,
s.broker,
s.targetMgr,
s.targetObserver,
s.checkerController,
s.proxyClientManager,
)
s.jobScheduler.Add(releaseJob)
err := releaseJob.Wait()
if err != nil {
msg := "failed to release partitions"
log.Warn(msg, zap.Error(err))
if len(req.GetPartitionIDs()) == 0 {
err := merr.WrapErrParameterInvalid("any partition", "empty partition list")
logger.Warn("no partition to release", zap.Error(err))
metrics.QueryCoordReleaseCount.WithLabelValues(metrics.FailLabel).Inc()
return merr.Status(errors.Wrap(err, msg)), nil
return merr.Status(err), nil
}
collectionReleased, err := s.broadcastAlterLoadConfigCollectionV2ForReleasePartitions(ctx, req)
if err != nil {
if errors.Is(err, job.ErrIgnoredAlterLoadConfig) {
logger.Info("release partitions ignored, partitions are already released")
metrics.QueryCoordReleaseCount.WithLabelValues(metrics.SuccessLabel).Inc()
return merr.Success(), nil
}
logger.Warn("failed to release partitions", zap.Error(err))
metrics.QueryCoordReleaseCount.WithLabelValues(metrics.FailLabel).Inc()
return merr.Status(err), nil
}
if collectionReleased {
job.WaitCollectionReleased(s.dist, s.checkerController, req.GetCollectionID())
} else {
job.WaitCurrentTargetUpdated(ctx, s.targetObserver, req.GetCollectionID())
job.WaitCollectionReleased(s.dist, s.checkerController, req.GetCollectionID(), req.GetPartitionIDs()...)
}
logger.Info("release partitions done", zap.Bool("collectionReleased", collectionReleased))
metrics.QueryCoordReleaseCount.WithLabelValues(metrics.SuccessLabel).Inc()
metrics.QueryCoordReleaseLatency.WithLabelValues().Observe(float64(tr.ElapseSpan().Milliseconds()))
meta.GlobalFailedLoadCache.Remove(req.GetCollectionID())
return merr.Success(), nil
}


@ -18,6 +18,7 @@ package querycoordv2
import (
"context"
"fmt"
"sort"
"testing"
"time"
@ -32,6 +33,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/rgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/json"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
@ -61,11 +63,13 @@ import (
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/etcd"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/retry"
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -122,7 +126,9 @@ func initStreamingSystem() {
LastConfirmedMessageID: walimplstest.NewTestMessageID(1),
}
}
registry.CallMessageAckCallback(context.Background(), msg, results)
retry.Do(context.Background(), func() error {
return registry.CallMessageAckCallback(context.Background(), msg, results)
}, retry.AttemptAlways())
return &types.BroadcastAppendResult{}, nil
})
bapi.EXPECT().Close().Return()
@ -131,8 +137,7 @@ func initStreamingSystem() {
mb.EXPECT().WithResourceKeys(mock.Anything, mock.Anything).Return(bapi, nil).Maybe()
mb.EXPECT().WithResourceKeys(mock.Anything, mock.Anything, mock.Anything).Return(bapi, nil).Maybe()
mb.EXPECT().WithResourceKeys(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(bapi, nil).Maybe()
mb.EXPECT().Close().Return()
broadcast.Release()
mb.EXPECT().Close().Return().Maybe()
broadcast.ResetBroadcaster()
broadcast.Register(mb)
}
@ -144,8 +149,8 @@ func (suite *ServiceSuite) SetupSuite() {
suite.collections = []int64{1000, 1001}
suite.partitions = map[int64][]int64{
1000: {100, 101},
1001: {102, 103},
1000: {100, 101, 102},
1001: {103, 104, 105},
}
suite.channels = map[int64][]string{
1000: {"1000-dmc0", "1000-dmc1"},
@ -155,16 +160,19 @@ func (suite *ServiceSuite) SetupSuite() {
1000: {
100: {1, 2},
101: {3, 4},
102: {5, 6},
},
1001: {
102: {5, 6},
103: {7, 8},
104: {9, 10},
105: {11, 12},
},
}
suite.loadTypes = map[int64]querypb.LoadType{
1000: querypb.LoadType_LoadCollection,
1001: querypb.LoadType_LoadPartition,
}
suite.replicaNumber = map[int64]int32{
1000: 1,
1001: 3,
@ -280,7 +288,35 @@ func (suite *ServiceSuite) SetupTest() {
suite.server.UpdateStateCode(commonpb.StateCode_Healthy)
suite.broker.EXPECT().GetCollectionLoadInfo(mock.Anything, mock.Anything).Return([]string{meta.DefaultResourceGroupName}, 1, nil).Maybe()
suite.broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, collectionID int64) (*milvuspb.DescribeCollectionResponse, error) {
for _, collection := range suite.collections {
if collection == collectionID {
return &milvuspb.DescribeCollectionResponse{
DbName: util.DefaultDBName,
DbId: 1,
CollectionID: collectionID,
CollectionName: fmt.Sprintf("collection_%d", collectionID),
Schema: &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{FieldID: 100},
{FieldID: 101},
{FieldID: 102},
},
},
}, nil
}
}
return &milvuspb.DescribeCollectionResponse{
Status: merr.Status(merr.ErrCollectionNotFound),
}, nil
}).Maybe()
suite.broker.EXPECT().GetPartitions(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, collectionID int64) ([]int64, error) {
partitionIDs, ok := suite.partitions[collectionID]
if !ok {
return nil, merr.WrapErrCollectionNotFound(collectionID)
}
return partitionIDs, nil
}).Maybe()
registry.ResetRegistration()
RegisterDDLCallbacks(suite.server)
}
@ -448,7 +484,6 @@ func (suite *ServiceSuite) TestLoadCollectionWithUserSpecifiedReplicaMode() {
collectionID := suite.collections[0]
// Mock broker methods using mockey
mockey.Mock(mockey.GetMethod(suite.broker, "DescribeCollection")).Return(nil, nil).Build()
suite.expectGetRecoverInfo(collectionID)
// Test when user specifies replica number - should set IsUserSpecifiedReplicaMode to true
@ -474,7 +509,6 @@ func (suite *ServiceSuite) TestLoadCollectionWithoutUserSpecifiedReplicaMode() {
collectionID := suite.collections[0]
// Mock broker methods using mockey
mockey.Mock(mockey.GetMethod(suite.broker, "DescribeCollection")).Return(nil, nil).Build()
suite.expectGetRecoverInfo(collectionID)
// Test when user doesn't specify replica number - should not set IsUserSpecifiedReplicaMode
@ -1077,7 +1111,6 @@ func (suite *ServiceSuite) TestLoadPartitionsWithUserSpecifiedReplicaMode() {
partitionIDs := suite.partitions[collectionID]
// Mock broker methods using mockey
mockey.Mock(mockey.GetMethod(suite.broker, "DescribeCollection")).Return(nil, nil).Build()
suite.expectGetRecoverInfo(collectionID)
// Test when user specifies replica number - should set IsUserSpecifiedReplicaMode to true
@ -1105,7 +1138,6 @@ func (suite *ServiceSuite) TestLoadPartitionsWithoutUserSpecifiedReplicaMode() {
partitionIDs := suite.partitions[collectionID]
// Mock broker methods using mockey
mockey.Mock(mockey.GetMethod(suite.broker, "DescribeCollection")).Return(nil, nil).Build()
suite.expectGetRecoverInfo(collectionID)
// Test when user doesn't specify replica number - should not set IsUserSpecifiedReplicaMode
@ -1955,21 +1987,8 @@ func (suite *ServiceSuite) loadAll() {
CollectionID: collection,
ReplicaNumber: suite.replicaNumber[collection],
}
job := job.NewLoadCollectionJob(
ctx,
req,
suite.dist,
suite.meta,
suite.broker,
suite.targetMgr,
suite.targetObserver,
suite.collectionObserver,
suite.nodeMgr,
false,
)
suite.jobScheduler.Add(job)
err := job.Wait()
suite.NoError(err)
resp, err := suite.server.LoadCollection(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.EqualValues(suite.replicaNumber[collection], suite.meta.GetReplicaNumber(ctx, collection))
suite.True(suite.meta.Exist(ctx, collection))
suite.NotNil(suite.meta.GetCollection(ctx, collection))
@ -1980,21 +1999,8 @@ func (suite *ServiceSuite) loadAll() {
PartitionIDs: suite.partitions[collection],
ReplicaNumber: suite.replicaNumber[collection],
}
job := job.NewLoadPartitionJob(
ctx,
req,
suite.dist,
suite.meta,
suite.broker,
suite.targetMgr,
suite.targetObserver,
suite.collectionObserver,
suite.nodeMgr,
false,
)
suite.jobScheduler.Add(job)
err := job.Wait()
suite.NoError(err)
resp, err := suite.server.LoadPartitions(ctx, req)
suite.Require().NoError(merr.CheckRPCCall(resp, err))
suite.EqualValues(suite.replicaNumber[collection], suite.meta.GetReplicaNumber(ctx, collection))
suite.True(suite.meta.Exist(ctx, collection))
suite.NotNil(suite.meta.GetPartitionsByCollection(ctx, collection))
@ -2059,6 +2065,12 @@ func (suite *ServiceSuite) assertSegments(collection int64, segments []*querypb.
return true
}
func (suite *ServiceSuite) expectGetRecoverInfoForAllCollections() {
for _, collection := range suite.collections {
suite.expectGetRecoverInfo(collection)
}
}
func (suite *ServiceSuite) expectGetRecoverInfo(collection int64) {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe()
vChannels := []*datapb.VchannelInfo{}
@ -2120,6 +2132,7 @@ func (suite *ServiceSuite) updateChannelDist(ctx context.Context, collection int
segments := lo.Flatten(lo.Values(suite.segments[collection]))
replicas := suite.meta.ReplicaManager.GetByCollection(ctx, collection)
targetVersion := suite.targetMgr.GetCollectionTargetVersion(ctx, collection, meta.CurrentTargetFirst)
for _, replica := range replicas {
i := 0
for _, node := range suite.sortInt64(replica.GetNodes()) {
@ -2139,6 +2152,7 @@ func (suite *ServiceSuite) updateChannelDist(ctx context.Context, collection int
Version: time.Now().Unix(),
}
}),
TargetVersion: targetVersion,
Status: &querypb.LeaderViewStatus{
Serviceable: true,
},
@ -2152,6 +2166,16 @@ func (suite *ServiceSuite) updateChannelDist(ctx context.Context, collection int
}
}
func (suite *ServiceSuite) releaseSegmentDist(nodeID int64) {
suite.dist.SegmentDistManager.Update(nodeID)
}
func (suite *ServiceSuite) releaseAllChannelDist() {
for _, node := range suite.nodes {
suite.dist.ChannelDistManager.Update(node)
}
}
func (suite *ServiceSuite) sortInt64(ints []int64) []int64 {
sort.Slice(ints, func(i int, j int) bool {
return ints[i] < ints[j]
@ -2219,6 +2243,8 @@ func (suite *ServiceSuite) fetchHeartbeats(time time.Time) {
func (suite *ServiceSuite) TearDownTest() {
suite.targetObserver.Stop()
suite.collectionObserver.Stop()
suite.jobScheduler.Stop()
}
func TestService(t *testing.T) {

View File

@ -21,7 +21,6 @@ import (
"fmt"
"strings"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"go.uber.org/zap"
@ -109,8 +108,7 @@ func RecoverAllCollection(m *meta.Meta) {
func AssignReplica(ctx context.Context, m *meta.Meta, resourceGroups []string, replicaNumber int32, checkNodeNum bool) (map[string]int, error) {
if len(resourceGroups) != 0 && len(resourceGroups) != 1 && len(resourceGroups) != int(replicaNumber) {
return nil, errors.Errorf(
"replica=[%d] resource group=[%s], resource group num can only be 0, 1 or same as replica number", replicaNumber, strings.Join(resourceGroups, ","))
return nil, merr.WrapErrParameterInvalidMsg("replica=[%d] resource group=[%s], resource group num can only be 0, 1 or same as replica number", replicaNumber, strings.Join(resourceGroups, ","))
}
if streamingutil.IsStreamingServiceEnabled() {
@ -158,6 +156,19 @@ func AssignReplica(ctx context.Context, m *meta.Meta, resourceGroups []string, r
return replicaNumInRG, nil
}
// SpawnReplicasWithReplicaConfig spawns replicas with replica config.
func SpawnReplicasWithReplicaConfig(ctx context.Context, m *meta.Meta, params meta.SpawnWithReplicaConfigParams) ([]*meta.Replica, error) {
replicas, err := m.ReplicaManager.SpawnWithReplicaConfig(ctx, params)
if err != nil {
return nil, err
}
RecoverReplicaOfCollection(ctx, m, params.CollectionID)
if streamingutil.IsStreamingServiceEnabled() {
m.RecoverSQNodesInCollection(ctx, params.CollectionID, snmanager.StaticStreamingNodeManager.GetStreamingQueryNodeIDs())
}
return replicas, nil
}
// SpawnReplicasWithRG spawns replicas in rgs one by one for given collection.
func SpawnReplicasWithRG(ctx context.Context, m *meta.Meta, collection int64, resourceGroups []string,
replicaNumber int32, channels []string, loadPriority commonpb.LoadPriority,

View File

@ -157,6 +157,15 @@ func toKeyDataPairs(m map[string][]byte) []*commonpb.KeyDataPair {
return ret
}
// toMap converts []*commonpb.KeyDataPair to map[string][]byte.
func toMap(pairs []*commonpb.KeyDataPair) map[string][]byte {
m := make(map[string][]byte, len(pairs))
for _, pair := range pairs {
m[pair.Key] = pair.Data
}
return m
}
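A minimal round-trip sketch for these two helpers (the pchannel name and payload are hypothetical; both helpers live in this file):

    startPositions := map[string][]byte{"by-dev-rootcoord-dml_0": {0x01, 0x02}}
    pairs := toKeyDataPairs(startPositions) // -> []*commonpb.KeyDataPair
    restored := toMap(pairs)                // -> map[string][]byte, equal to startPositions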
func (b *ServerBroker) WatchChannels(ctx context.Context, info *watchInfo) error {
log.Ctx(ctx).Info("watching channels", zap.Uint64("ts", info.ts), zap.Int64("collection", info.collectionID), zap.Strings("vChannels", info.vChannels))

View File

@ -40,6 +40,7 @@ func TestCheckGeneralCapacity(t *testing.T) {
catalog.EXPECT().ListCollections(mock.Anything, mock.Anything, mock.Anything).Return(nil, nil)
catalog.EXPECT().ListAliases(mock.Anything, mock.Anything, mock.Anything).Return(nil, nil)
catalog.EXPECT().CreateDatabase(mock.Anything, mock.Anything, mock.Anything).Return(nil)
catalog.EXPECT().AlterCollection(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
allocator := mocktso.NewAllocator(t)
allocator.EXPECT().GenerateTSO(mock.Anything).Return(1000, nil)
@ -64,7 +65,7 @@ func TestCheckGeneralCapacity(t *testing.T) {
assert.NoError(t, err)
err = meta.AddCollection(ctx, &model.Collection{
CollectionID: 1,
State: pb.CollectionState_CollectionCreating,
State: pb.CollectionState_CollectionCreated,
ShardsNum: 256,
Partitions: []*model.Partition{
{PartitionID: 100, State: pb.PartitionState_PartitionCreated},
@ -73,19 +74,11 @@ func TestCheckGeneralCapacity(t *testing.T) {
})
assert.NoError(t, err)
assert.Equal(t, 0, meta.GetGeneralCount(ctx))
err = checkGeneralCapacity(ctx, 1, 2, 256, core)
assert.NoError(t, err)
catalog.EXPECT().AlterCollection(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
err = meta.ChangeCollectionState(ctx, 1, pb.CollectionState_CollectionCreated, typeutil.MaxTimestamp)
assert.NoError(t, err)
assert.Equal(t, 512, meta.GetGeneralCount(ctx))
err = checkGeneralCapacity(ctx, 1, 1, 1, core)
assert.Error(t, err)
err = meta.ChangeCollectionState(ctx, 1, pb.CollectionState_CollectionDropping, typeutil.MaxTimestamp)
err = meta.DropCollection(ctx, 1, typeutil.MaxTimestamp)
assert.NoError(t, err)
assert.Equal(t, 0, meta.GetGeneralCount(ctx))

View File

@ -25,56 +25,34 @@ import (
"github.com/twpayne/go-geom/encoding/wkb"
"github.com/twpayne/go-geom/encoding/wkt"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/coordinator/snmanager"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/internal/util/proxyutil"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
ms "github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
pb "github.com/milvus-io/milvus/pkg/v2/proto/etcdpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message/adaptor"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
type collectionChannels struct {
virtualChannels []string
physicalChannels []string
}
type createCollectionTask struct {
baseTask
Req *milvuspb.CreateCollectionRequest
schema *schemapb.CollectionSchema
collID UniqueID
partIDs []UniqueID
channels collectionChannels
dbID UniqueID
partitionNames []string
dbProperties []*commonpb.KeyValuePair
*Core
Req *milvuspb.CreateCollectionRequest
header *message.CreateCollectionMessageHeader
body *message.CreateCollectionRequest
}
func (t *createCollectionTask) validate(ctx context.Context) error {
if t.Req == nil {
return errors.New("empty requests")
}
if err := CheckMsgType(t.Req.GetBase().GetMsgType(), commonpb.MsgType_CreateCollection); err != nil {
return err
}
Params := paramtable.Get()
// 1. check shard number
shardsNum := t.Req.GetShardsNum()
@ -94,7 +72,7 @@ func (t *createCollectionTask) validate(ctx context.Context) error {
}
// 2. check db-collection capacity
db2CollIDs := t.core.meta.ListAllAvailCollections(t.ctx)
db2CollIDs := t.meta.ListAllAvailCollections(ctx)
if err := t.checkMaxCollectionsPerDB(ctx, db2CollIDs); err != nil {
return err
}
@ -116,18 +94,20 @@ func (t *createCollectionTask) validate(ctx context.Context) error {
if t.Req.GetNumPartitions() > 0 {
newPartNum = t.Req.GetNumPartitions()
}
return checkGeneralCapacity(t.ctx, 1, newPartNum, t.Req.GetShardsNum(), t.core)
return checkGeneralCapacity(ctx, 1, newPartNum, t.Req.GetShardsNum(), t.Core)
}
// checkMaxCollectionsPerDB DB properties take precedence over quota configurations for max collections.
func (t *createCollectionTask) checkMaxCollectionsPerDB(ctx context.Context, db2CollIDs map[int64][]int64) error {
collIDs, ok := db2CollIDs[t.dbID]
Params := paramtable.Get()
collIDs, ok := db2CollIDs[t.header.DbId]
if !ok {
log.Ctx(ctx).Warn("can not found DB ID", zap.String("collection", t.Req.GetCollectionName()), zap.String("dbName", t.Req.GetDbName()))
return merr.WrapErrDatabaseNotFound(t.Req.GetDbName(), "failed to create collection")
}
db, err := t.core.meta.GetDatabaseByName(t.ctx, t.Req.GetDbName(), typeutil.MaxTimestamp)
db, err := t.meta.GetDatabaseByName(ctx, t.Req.GetDbName(), typeutil.MaxTimestamp)
if err != nil {
log.Ctx(ctx).Warn("can not found DB ID", zap.String("collection", t.Req.GetCollectionName()), zap.String("dbName", t.Req.GetDbName()))
return merr.WrapErrDatabaseNotFound(t.Req.GetDbName(), "failed to create collection")
@ -271,6 +251,24 @@ func (t *createCollectionTask) appendDynamicField(ctx context.Context, schema *s
}
}
func (t *createCollectionTask) appendConsistecyLevel() {
if ok, _ := getConsistencyLevel(t.Req.Properties...); ok {
return
}
for _, p := range t.Req.Properties {
if p.GetKey() == common.ConsistencyLevel {
// if there's already a consistency level, overwrite it.
p.Value = strconv.Itoa(int(t.Req.ConsistencyLevel))
return
}
}
// append the consistency level to the request properties so it is persisted with the schema
t.Req.Properties = append(t.Req.Properties, &commonpb.KeyValuePair{
Key: common.ConsistencyLevel,
Value: strconv.Itoa(int(t.Req.ConsistencyLevel)),
})
}
func (t *createCollectionTask) handleNamespaceField(ctx context.Context, schema *schemapb.CollectionSchema) error {
if !Params.CommonCfg.EnableNamespace.GetAsBool() {
return nil
@ -310,7 +308,7 @@ func (t *createCollectionTask) handleNamespaceField(ctx context.Context, schema
{Key: common.MaxLengthKey, Value: fmt.Sprintf("%d", paramtable.Get().ProxyCfg.MaxVarCharLength.GetAsInt())},
},
})
schema.Properties = append(schema.Properties, &commonpb.KeyValuePair{
t.Req.Properties = append(t.Req.Properties, &commonpb.KeyValuePair{
Key: common.PartitionKeyIsolationKey,
Value: "true",
})
@ -347,48 +345,40 @@ func (t *createCollectionTask) appendSysFields(schema *schemapb.CollectionSchema
}
func (t *createCollectionTask) prepareSchema(ctx context.Context) error {
var schema schemapb.CollectionSchema
if err := proto.Unmarshal(t.Req.GetSchema(), &schema); err != nil {
if err := t.validateSchema(ctx, t.body.CollectionSchema); err != nil {
return err
}
if err := t.validateSchema(ctx, &schema); err != nil {
return err
}
t.appendDynamicField(ctx, &schema)
if err := t.handleNamespaceField(ctx, &schema); err != nil {
t.appendConsistecyLevel()
t.appendDynamicField(ctx, t.body.CollectionSchema)
if err := t.handleNamespaceField(ctx, t.body.CollectionSchema); err != nil {
return err
}
if err := t.assignFieldAndFunctionID(&schema); err != nil {
if err := t.assignFieldAndFunctionID(t.body.CollectionSchema); err != nil {
return err
}
// Set properties for persistence
schema.Properties = t.Req.GetProperties()
t.appendSysFields(&schema)
t.schema = &schema
t.body.CollectionSchema.Properties = t.Req.GetProperties()
t.appendSysFields(t.body.CollectionSchema)
return nil
}
func (t *createCollectionTask) assignShardsNum() {
if t.Req.GetShardsNum() <= 0 {
t.Req.ShardsNum = common.DefaultShardsNum
}
}
func (t *createCollectionTask) assignCollectionID() error {
var err error
t.collID, err = t.core.idAllocator.AllocOne()
t.header.CollectionId, err = t.idAllocator.AllocOne()
t.body.CollectionID = t.header.CollectionId
return err
}
func (t *createCollectionTask) assignPartitionIDs(ctx context.Context) error {
t.partitionNames = make([]string, 0)
Params := paramtable.Get()
partitionNames := make([]string, 0, t.Req.GetNumPartitions())
defaultPartitionName := Params.CommonCfg.DefaultPartitionName.GetValue()
_, err := typeutil.GetPartitionKeyFieldSchema(t.schema)
if err == nil {
if _, err := typeutil.GetPartitionKeyFieldSchema(t.body.CollectionSchema); err == nil {
// only when enabling partition key mode, we allow to create multiple partitions.
partitionNums := t.Req.GetNumPartitions()
// double check, default num of physical partitions should be greater than 0
if partitionNums <= 0 {
@ -402,68 +392,58 @@ func (t *createCollectionTask) assignPartitionIDs(ctx context.Context) error {
}
for i := int64(0); i < partitionNums; i++ {
t.partitionNames = append(t.partitionNames, fmt.Sprintf("%s_%d", defaultPartitionName, i))
partitionNames = append(partitionNames, fmt.Sprintf("%s_%d", defaultPartitionName, i))
}
} else {
// compatible with old versions <= 2.2.8
t.partitionNames = append(t.partitionNames, defaultPartitionName)
partitionNames = append(partitionNames, defaultPartitionName)
}
t.partIDs = make([]UniqueID, len(t.partitionNames))
start, end, err := t.core.idAllocator.Alloc(uint32(len(t.partitionNames)))
// allocate partition ids
start, end, err := t.idAllocator.Alloc(uint32(len(partitionNames)))
if err != nil {
return err
}
t.header.PartitionIds = make([]int64, len(partitionNames))
t.body.PartitionIDs = make([]int64, len(partitionNames))
for i := start; i < end; i++ {
t.header.PartitionIds[i-start] = i
t.body.PartitionIDs[i-start] = i
}
t.body.PartitionNames = partitionNames
log.Ctx(ctx).Info("assign partitions when create collection",
zap.String("collectionName", t.Req.GetCollectionName()),
zap.Int64s("partitionIds", t.header.PartitionIds),
zap.Strings("partitionNames", t.body.PartitionNames))
return nil
}
func (t *createCollectionTask) assignChannels(ctx context.Context) error {
vchannels, err := snmanager.StaticStreamingNodeManager.AllocVirtualChannels(ctx, balancer.AllocVChannelParam{
CollectionID: t.header.GetCollectionId(),
Num: int(t.Req.GetShardsNum()),
})
if err != nil {
return err
}
for i := start; i < end; i++ {
t.partIDs[i-start] = i
}
log.Ctx(ctx).Info("assign partitions when create collection",
zap.String("collectionName", t.Req.GetCollectionName()),
zap.Strings("partitionNames", t.partitionNames))
return nil
}
func (t *createCollectionTask) assignChannels() error {
vchanNames := make([]string, t.Req.GetShardsNum())
// physical channel names
chanNames := t.core.chanTimeTick.getDmlChannelNames(int(t.Req.GetShardsNum()))
if int32(len(chanNames)) < t.Req.GetShardsNum() {
return fmt.Errorf("no enough channels, want: %d, got: %d", t.Req.GetShardsNum(), len(chanNames))
}
shardNum := int(t.Req.GetShardsNum())
for i := 0; i < shardNum; i++ {
vchanNames[i] = funcutil.GetVirtualChannel(chanNames[i], t.collID, i)
}
t.channels = collectionChannels{
virtualChannels: vchanNames,
physicalChannels: chanNames,
for _, vchannel := range vchannels {
t.body.PhysicalChannelNames = append(t.body.PhysicalChannelNames, funcutil.ToPhysicalChannel(vchannel))
t.body.VirtualChannelNames = append(t.body.VirtualChannelNames, vchannel)
}
return nil
}
func (t *createCollectionTask) Prepare(ctx context.Context) error {
db, err := t.core.meta.GetDatabaseByName(ctx, t.Req.GetDbName(), typeutil.MaxTimestamp)
t.body.Base = &commonpb.MsgBase{
MsgType: commonpb.MsgType_CreateCollection,
}
db, err := t.meta.GetDatabaseByName(ctx, t.Req.GetDbName(), typeutil.MaxTimestamp)
if err != nil {
return err
}
t.dbID = db.ID
dbReplicateID, _ := common.GetReplicateID(db.Properties)
if dbReplicateID != "" {
reqProperties := make([]*commonpb.KeyValuePair, 0, len(t.Req.Properties))
for _, prop := range t.Req.Properties {
if prop.Key == common.ReplicateIDKey {
continue
}
reqProperties = append(reqProperties, prop)
}
t.Req.Properties = reqProperties
}
t.dbProperties = db.Properties
// set collection timezone
properties := t.Req.GetProperties()
ok, _ := getDefaultTimezoneVal(properties...)
@ -476,10 +456,12 @@ func (t *createCollectionTask) Prepare(ctx context.Context) error {
t.Req.Properties = append(properties, timezoneKV)
}
if hookutil.GetEzPropByDBProperties(t.dbProperties) != nil {
t.Req.Properties = append(t.Req.Properties, hookutil.GetEzPropByDBProperties(t.dbProperties))
if hookutil.GetEzPropByDBProperties(db.Properties) != nil {
t.Req.Properties = append(t.Req.Properties, hookutil.GetEzPropByDBProperties(db.Properties))
}
t.header.DbId = db.ID
t.body.DbID = t.header.DbId
if err := t.validate(ctx); err != nil {
return err
}
@ -488,8 +470,6 @@ func (t *createCollectionTask) Prepare(ctx context.Context) error {
return err
}
t.assignShardsNum()
if err := t.assignCollectionID(); err != nil {
return err
}
@ -498,273 +478,29 @@ func (t *createCollectionTask) Prepare(ctx context.Context) error {
return err
}
return t.assignChannels()
}
func (t *createCollectionTask) genCreateCollectionMsg(ctx context.Context, ts uint64) *ms.MsgPack {
msgPack := ms.MsgPack{}
msg := &ms.CreateCollectionMsg{
BaseMsg: ms.BaseMsg{
Ctx: ctx,
BeginTimestamp: ts,
EndTimestamp: ts,
HashValues: []uint32{0},
},
CreateCollectionRequest: t.genCreateCollectionRequest(),
}
msgPack.Msgs = append(msgPack.Msgs, msg)
return &msgPack
}
func (t *createCollectionTask) genCreateCollectionRequest() *msgpb.CreateCollectionRequest {
collectionID := t.collID
partitionIDs := t.partIDs
// error won't happen here.
marshaledSchema, _ := proto.Marshal(t.schema)
pChannels := t.channels.physicalChannels
vChannels := t.channels.virtualChannels
return &msgpb.CreateCollectionRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_CreateCollection),
commonpbutil.WithTimeStamp(t.ts),
),
CollectionID: collectionID,
PartitionIDs: partitionIDs,
Schema: marshaledSchema,
VirtualChannelNames: vChannels,
PhysicalChannelNames: pChannels,
}
}
func (t *createCollectionTask) addChannelsAndGetStartPositions(ctx context.Context, ts uint64) (map[string][]byte, error) {
t.core.chanTimeTick.addDmlChannels(t.channels.physicalChannels...)
if streamingutil.IsStreamingServiceEnabled() {
return t.broadcastCreateCollectionMsgIntoStreamingService(ctx, ts)
}
msg := t.genCreateCollectionMsg(ctx, ts)
return t.core.chanTimeTick.broadcastMarkDmlChannels(t.channels.physicalChannels, msg)
}
func (t *createCollectionTask) broadcastCreateCollectionMsgIntoStreamingService(ctx context.Context, ts uint64) (map[string][]byte, error) {
notifier := snmanager.NewStreamingReadyNotifier()
if err := snmanager.StaticStreamingNodeManager.RegisterStreamingEnabledListener(ctx, notifier); err != nil {
return nil, err
}
if !notifier.IsReady() {
// streaming service is not ready, so we send it into msgstream.
defer notifier.Release()
msg := t.genCreateCollectionMsg(ctx, ts)
return t.core.chanTimeTick.broadcastMarkDmlChannels(t.channels.physicalChannels, msg)
}
// streaming service is ready, so we release the ready notifier and send it into streaming service.
notifier.Release()
req := t.genCreateCollectionRequest()
// dispatch the createCollectionMsg into all vchannel.
msgs := make([]message.MutableMessage, 0, len(req.VirtualChannelNames))
for _, vchannel := range req.VirtualChannelNames {
msg, err := message.NewCreateCollectionMessageBuilderV1().
WithVChannel(vchannel).
WithHeader(&message.CreateCollectionMessageHeader{
CollectionId: req.CollectionID,
PartitionIds: req.GetPartitionIDs(),
}).
WithBody(req).
BuildMutable()
if err != nil {
return nil, err
}
msgs = append(msgs, msg)
}
// send the createCollectionMsg into streaming service.
// ts is used as initial checkpoint at datacoord,
// it must be set as barrier time tick.
// The timetick of create message in wal must be greater than ts, to avoid data read loss at read side.
resps := streaming.WAL().AppendMessagesWithOption(ctx, streaming.AppendOption{
BarrierTimeTick: ts,
}, msgs...)
if err := resps.UnwrapFirstError(); err != nil {
return nil, err
}
// make the old message stream serialized id.
startPositions := make(map[string][]byte)
for idx, resp := range resps.Responses {
// The key is pchannel here
startPositions[req.PhysicalChannelNames[idx]] = adaptor.MustGetMQWrapperIDFromMessage(resp.AppendResult.MessageID).Serialize()
}
return startPositions, nil
}
func (t *createCollectionTask) getCreateTs(ctx context.Context) (uint64, error) {
replicateInfo := t.Req.GetBase().GetReplicateInfo()
if !replicateInfo.GetIsReplicate() {
return t.GetTs(), nil
}
if replicateInfo.GetMsgTimestamp() == 0 {
log.Ctx(ctx).Warn("the cdc timestamp is not set in the request for the backup instance")
return 0, merr.WrapErrParameterInvalidMsg("the cdc timestamp is not set in the request for the backup instance")
}
return replicateInfo.GetMsgTimestamp(), nil
}
func (t *createCollectionTask) Execute(ctx context.Context) error {
collID := t.collID
partIDs := t.partIDs
ts, err := t.getCreateTs(ctx)
if err != nil {
if err := t.assignChannels(ctx); err != nil {
return err
}
vchanNames := t.channels.virtualChannels
chanNames := t.channels.physicalChannels
partitions := make([]*model.Partition, len(partIDs))
for i, partID := range partIDs {
partitions[i] = &model.Partition{
PartitionID: partID,
PartitionName: t.partitionNames[i],
PartitionCreatedTimestamp: ts,
CollectionID: collID,
State: pb.PartitionState_PartitionCreated,
}
}
ConsistencyLevel := t.Req.ConsistencyLevel
if ok, level := getConsistencyLevel(t.Req.Properties...); ok {
ConsistencyLevel = level
}
collInfo := model.Collection{
CollectionID: collID,
DBID: t.dbID,
Name: t.schema.Name,
DBName: t.Req.GetDbName(),
Description: t.schema.Description,
AutoID: t.schema.AutoID,
Fields: model.UnmarshalFieldModels(t.schema.Fields),
StructArrayFields: model.UnmarshalStructArrayFieldModels(t.schema.StructArrayFields),
Functions: model.UnmarshalFunctionModels(t.schema.Functions),
VirtualChannelNames: vchanNames,
PhysicalChannelNames: chanNames,
ShardsNum: t.Req.ShardsNum,
ConsistencyLevel: ConsistencyLevel,
CreateTime: ts,
State: pb.CollectionState_CollectionCreating,
Partitions: partitions,
Properties: t.Req.Properties,
EnableDynamicField: t.schema.EnableDynamicField,
UpdateTimestamp: ts,
}
return t.validateIfCollectionExists(ctx)
}
func (t *createCollectionTask) validateIfCollectionExists(ctx context.Context) error {
// Check if the collection name duplicates an alias.
_, err = t.core.meta.DescribeAlias(ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), typeutil.MaxTimestamp)
if err == nil {
if _, err := t.meta.DescribeAlias(ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), typeutil.MaxTimestamp); err == nil {
err2 := fmt.Errorf("collection name [%s] conflicts with an existing alias, please choose a unique name", t.Req.GetCollectionName())
log.Ctx(ctx).Warn("create collection failed", zap.String("database", t.Req.GetDbName()), zap.Error(err2))
return err2
}
// We cannot check the idempotency inside meta table when adding collection, since we'll execute duplicate steps
// if add collection successfully due to idempotency check. Some steps may be risky to be duplicate executed if they
// are not promised idempotent.
clone := collInfo.Clone()
// need double check in meta table if we can't promise the sequence execution.
existedCollInfo, err := t.core.meta.GetCollectionByName(ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), typeutil.MaxTimestamp)
// Check if the collection already exists.
existedCollInfo, err := t.meta.GetCollectionByName(ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), typeutil.MaxTimestamp)
if err == nil {
equal := existedCollInfo.Equal(*clone)
if !equal {
newCollInfo := newCollectionModel(t.header, t.body, 0)
if equal := existedCollInfo.Equal(*newCollInfo); !equal {
return fmt.Errorf("create duplicate collection with different parameters, collection: %s", t.Req.GetCollectionName())
}
// make creating collection idempotent.
log.Ctx(ctx).Warn("add duplicate collection", zap.String("collection", t.Req.GetCollectionName()), zap.Uint64("ts", ts))
return nil
return errIgnoredCreateCollection
}
log.Ctx(ctx).Info("check collection existence", zap.String("collection", t.Req.GetCollectionName()), zap.Error(err))
// TODO: The create collection is not idempotent for other component, such as wal.
// we need to make the create collection operation must success after some persistent operation, refactor it in future.
startPositions, err := t.addChannelsAndGetStartPositions(ctx, ts)
if err != nil {
// ugly here, since we must get start positions first.
t.core.chanTimeTick.removeDmlChannels(t.channels.physicalChannels...)
return err
}
collInfo.StartPositions = toKeyDataPairs(startPositions)
return executeCreateCollectionTaskSteps(ctx, t.core, &collInfo, t.Req.GetDbName(), t.dbProperties, ts)
}
func (t *createCollectionTask) GetLockerKey() LockerKey {
return NewLockerKeyChain(
NewClusterLockerKey(false),
NewDatabaseLockerKey(t.Req.GetDbName(), false),
NewCollectionLockerKey(strconv.FormatInt(t.collID, 10), true),
)
}
func executeCreateCollectionTaskSteps(ctx context.Context,
core *Core,
col *model.Collection,
dbName string,
dbProperties []*commonpb.KeyValuePair,
ts Timestamp,
) error {
undoTask := newBaseUndoTask(core.stepExecutor)
collID := col.CollectionID
undoTask.AddStep(&expireCacheStep{
baseStep: baseStep{core: core},
dbName: dbName,
collectionNames: []string{col.Name},
collectionID: collID,
ts: ts,
opts: []proxyutil.ExpireCacheOpt{proxyutil.SetMsgType(commonpb.MsgType_DropCollection)},
}, &nullStep{})
undoTask.AddStep(&nullStep{}, &removeDmlChannelsStep{
baseStep: baseStep{core: core},
pChannels: col.PhysicalChannelNames,
}) // remove dml channels if any error occurs.
undoTask.AddStep(&addCollectionMetaStep{
baseStep: baseStep{core: core},
coll: col,
}, &deleteCollectionMetaStep{
baseStep: baseStep{core: core},
collectionID: collID,
// When we undo createCollectionTask, this ts may be less than the ts when unwatch channels.
ts: ts,
})
// serve for this case: watching channels succeed in datacoord but failed due to network failure.
undoTask.AddStep(&nullStep{}, &unwatchChannelsStep{
baseStep: baseStep{core: core},
collectionID: collID,
channels: collectionChannels{
virtualChannels: col.VirtualChannelNames,
physicalChannels: col.PhysicalChannelNames,
},
isSkip: !Params.CommonCfg.TTMsgEnabled.GetAsBool(),
})
undoTask.AddStep(&watchChannelsStep{
baseStep: baseStep{core: core},
info: &watchInfo{
ts: ts,
collectionID: collID,
vChannels: col.VirtualChannelNames,
startPositions: col.StartPositions,
schema: &schemapb.CollectionSchema{
Name: col.Name,
DbName: col.DBName,
Description: col.Description,
AutoID: col.AutoID,
Fields: model.MarshalFieldModels(col.Fields),
StructArrayFields: model.MarshalStructArrayFieldModels(col.StructArrayFields),
Properties: col.Properties,
Functions: model.MarshalFunctionModels(col.Functions),
},
dbProperties: dbProperties,
},
}, &nullStep{})
undoTask.AddStep(&changeCollectionStateStep{
baseStep: baseStep{core: core},
collectionID: collID,
state: pb.CollectionState_CollectionCreated,
ts: ts,
}, &nullStep{}) // We'll remove the whole collection anyway.
return undoTask.Execute(ctx)
return nil
}

File diff suppressed because it is too large

View File

@ -1,149 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"fmt"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/coordinator/snmanager"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/util/proxyutil"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/pkg/v2/log"
pb "github.com/milvus-io/milvus/pkg/v2/proto/etcdpb"
)
type createPartitionTask struct {
baseTask
Req *milvuspb.CreatePartitionRequest
collMeta *model.Collection
}
func (t *createPartitionTask) Prepare(ctx context.Context) error {
if err := CheckMsgType(t.Req.GetBase().GetMsgType(), commonpb.MsgType_CreatePartition); err != nil {
return err
}
collMeta, err := t.core.meta.GetCollectionByName(ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), t.GetTs())
if err != nil {
return err
}
t.collMeta = collMeta
return checkGeneralCapacity(ctx, 0, 1, 0, t.core)
}
func (t *createPartitionTask) Execute(ctx context.Context) error {
for _, partition := range t.collMeta.Partitions {
if partition.PartitionName == t.Req.GetPartitionName() {
log.Ctx(ctx).Warn("add duplicate partition", zap.String("collection", t.Req.GetCollectionName()), zap.String("partition", t.Req.GetPartitionName()), zap.Uint64("ts", t.GetTs()))
return nil
}
}
cfgMaxPartitionNum := Params.RootCoordCfg.MaxPartitionNum.GetAsInt()
if len(t.collMeta.Partitions) >= cfgMaxPartitionNum {
return fmt.Errorf("partition number (%d) exceeds max configuration (%d), collection: %s",
len(t.collMeta.Partitions), cfgMaxPartitionNum, t.collMeta.Name)
}
partID, err := t.core.idAllocator.AllocOne()
if err != nil {
return err
}
partition := &model.Partition{
PartitionID: partID,
PartitionName: t.Req.GetPartitionName(),
PartitionCreatedTimestamp: t.GetTs(),
Extra: nil,
CollectionID: t.collMeta.CollectionID,
State: pb.PartitionState_PartitionCreating,
}
return executeCreatePartitionTaskSteps(ctx, t.core, partition, t.collMeta, t.Req.GetDbName(), t.GetTs())
}
func (t *createPartitionTask) GetLockerKey() LockerKey {
collection := t.core.getCollectionIDStr(t.ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), 0)
return NewLockerKeyChain(
NewClusterLockerKey(false),
NewDatabaseLockerKey(t.Req.GetDbName(), false),
NewCollectionLockerKey(collection, true),
)
}
func executeCreatePartitionTaskSteps(ctx context.Context,
core *Core,
partition *model.Partition,
col *model.Collection,
dbName string,
ts Timestamp,
) error {
undoTask := newBaseUndoTask(core.stepExecutor)
partID := partition.PartitionID
collectionID := partition.CollectionID
undoTask.AddStep(&expireCacheStep{
baseStep: baseStep{core: core},
dbName: dbName,
collectionNames: []string{col.Name},
collectionID: collectionID,
partitionName: partition.PartitionName,
ts: ts,
opts: []proxyutil.ExpireCacheOpt{proxyutil.SetMsgType(commonpb.MsgType_CreatePartition)},
}, &nullStep{})
undoTask.AddStep(&addPartitionMetaStep{
baseStep: baseStep{core: core},
partition: partition,
}, &removePartitionMetaStep{
baseStep: baseStep{core: core},
dbID: col.DBID,
collectionID: partition.CollectionID,
partitionID: partition.PartitionID,
ts: ts,
})
if streamingutil.IsStreamingServiceEnabled() {
if err := snmanager.StaticStreamingNodeManager.CheckIfStreamingServiceReady(ctx); err == nil {
undoTask.AddStep(&broadcastCreatePartitionMsgStep{
baseStep: baseStep{core: core},
vchannels: col.VirtualChannelNames,
partition: partition,
ts: ts,
}, &nullStep{})
}
}
undoTask.AddStep(&nullStep{}, &releasePartitionsStep{
baseStep: baseStep{core: core},
collectionID: col.CollectionID,
partitionIDs: []int64{partID},
})
undoTask.AddStep(&changePartitionStateStep{
baseStep: baseStep{core: core},
collectionID: col.CollectionID,
partitionID: partID,
state: pb.PartitionState_PartitionCreated,
ts: ts,
}, &nullStep{})
return undoTask.Execute(ctx)
}

View File

@ -1,173 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/metastore/model"
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
"github.com/milvus-io/milvus/pkg/v2/proto/etcdpb"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
)
func Test_createPartitionTask_Prepare(t *testing.T) {
t.Run("invalid msg type", func(t *testing.T) {
task := &createPartitionTask{
Req: &milvuspb.CreatePartitionRequest{Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropCollection}},
}
err := task.Prepare(context.Background())
assert.Error(t, err)
})
t.Run("failed to get collection meta", func(t *testing.T) {
core := newTestCore(withInvalidMeta())
task := &createPartitionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.CreatePartitionRequest{Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreatePartition}},
}
err := task.Prepare(context.Background())
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
collectionName := funcutil.GenRandomStr()
coll := &model.Collection{Name: collectionName}
meta := mockrootcoord.NewIMetaTable(t)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(coll.Clone(), nil)
meta.EXPECT().GetGeneralCount(mock.Anything).Return(0)
core := newTestCore(withMeta(meta))
task := &createPartitionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.CreatePartitionRequest{Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_CreatePartition}},
}
err := task.Prepare(context.Background())
assert.NoError(t, err)
assert.True(t, coll.Equal(*task.collMeta))
})
}
func Test_createPartitionTask_Execute(t *testing.T) {
t.Run("create duplicate partition", func(t *testing.T) {
collectionName := funcutil.GenRandomStr()
partitionName := funcutil.GenRandomStr()
coll := &model.Collection{Name: collectionName, Partitions: []*model.Partition{{PartitionName: partitionName}}}
task := &createPartitionTask{
collMeta: coll,
Req: &milvuspb.CreatePartitionRequest{CollectionName: collectionName, PartitionName: partitionName},
}
err := task.Execute(context.Background())
assert.NoError(t, err)
})
t.Run("create too many partitions", func(t *testing.T) {
cfgMaxPartitionNum := Params.RootCoordCfg.MaxPartitionNum.GetAsInt()
partitions := make([]*model.Partition, 0, cfgMaxPartitionNum)
for i := 0; i < cfgMaxPartitionNum; i++ {
partitions = append(partitions, &model.Partition{})
}
collectionName := funcutil.GenRandomStr()
partitionName := funcutil.GenRandomStr()
coll := &model.Collection{Name: collectionName, Partitions: partitions}
task := &createPartitionTask{
collMeta: coll,
Req: &milvuspb.CreatePartitionRequest{CollectionName: collectionName, PartitionName: partitionName},
}
err := task.Execute(context.Background())
assert.Error(t, err)
})
t.Run("failed to allocate partition id", func(t *testing.T) {
collectionName := funcutil.GenRandomStr()
partitionName := funcutil.GenRandomStr()
coll := &model.Collection{Name: collectionName, Partitions: []*model.Partition{}}
core := newTestCore(withInvalidIDAllocator())
task := &createPartitionTask{
baseTask: newBaseTask(context.Background(), core),
collMeta: coll,
Req: &milvuspb.CreatePartitionRequest{CollectionName: collectionName, PartitionName: partitionName},
}
err := task.Execute(context.Background())
assert.Error(t, err)
})
t.Run("failed to expire cache", func(t *testing.T) {
collectionName := funcutil.GenRandomStr()
partitionName := funcutil.GenRandomStr()
coll := &model.Collection{Name: collectionName, Partitions: []*model.Partition{}}
core := newTestCore(withValidIDAllocator(), withInvalidProxyManager())
task := &createPartitionTask{
baseTask: newBaseTask(context.Background(), core),
collMeta: coll,
Req: &milvuspb.CreatePartitionRequest{CollectionName: collectionName, PartitionName: partitionName},
}
err := task.Execute(context.Background())
assert.Error(t, err)
})
t.Run("failed to add partition meta", func(t *testing.T) {
collectionName := funcutil.GenRandomStr()
partitionName := funcutil.GenRandomStr()
coll := &model.Collection{Name: collectionName, Partitions: []*model.Partition{}}
core := newTestCore(withValidIDAllocator(), withValidProxyManager(), withInvalidMeta())
task := &createPartitionTask{
baseTask: newBaseTask(context.Background(), core),
collMeta: coll,
Req: &milvuspb.CreatePartitionRequest{CollectionName: collectionName, PartitionName: partitionName},
}
err := task.Execute(context.Background())
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
collectionName := funcutil.GenRandomStr()
partitionName := funcutil.GenRandomStr()
coll := &model.Collection{Name: collectionName, Partitions: []*model.Partition{}}
meta := newMockMetaTable()
meta.AddPartitionFunc = func(ctx context.Context, partition *model.Partition) error {
return nil
}
meta.ChangePartitionStateFunc = func(ctx context.Context, collectionID UniqueID, partitionID UniqueID, state etcdpb.PartitionState, ts Timestamp) error {
return nil
}
b := newMockBroker()
b.SyncNewCreatedPartitionFunc = func(ctx context.Context, collectionID UniqueID, partitionID UniqueID) error {
return nil
}
core := newTestCore(withValidIDAllocator(), withValidProxyManager(), withMeta(meta), withBroker(b))
task := &createPartitionTask{
baseTask: newBaseTask(context.Background(), core),
collMeta: coll,
Req: &milvuspb.CreatePartitionRequest{CollectionName: collectionName, PartitionName: partitionName},
}
err := task.Execute(context.Background())
assert.NoError(t, err)
})
}

View File

@ -36,6 +36,9 @@ func RegisterDDLCallbacks(core *Core) {
ddlCallback := &DDLCallback{
Core: core,
}
ddlCallback.registerCollectionCallbacks()
ddlCallback.registerPartitionCallbacks()
ddlCallback.registerRBACCallbacks()
ddlCallback.registerDatabaseCallbacks()
ddlCallback.registerAliasCallbacks()
@ -69,6 +72,18 @@ func (c *DDLCallback) registerAliasCallbacks() {
registry.RegisterDropAliasV2AckCallback(c.dropAliasV2AckCallback)
}
// registerCollectionCallbacks registers the collection callbacks.
func (c *DDLCallback) registerCollectionCallbacks() {
registry.RegisterCreateCollectionV1AckCallback(c.createCollectionV1AckCallback)
registry.RegisterDropCollectionV1AckCallback(c.dropCollectionV1AckCallback)
}
// registerPartitionCallbacks registers the partition callbacks.
func (c *DDLCallback) registerPartitionCallbacks() {
registry.RegisterCreatePartitionV1AckCallback(c.createPartitionV1AckCallback)
registry.RegisterDropPartitionV1AckCallback(c.dropPartitionV1AckCallback)
}
// DDLCallback is the callback of ddl.
type DDLCallback struct {
*Core
@ -138,3 +153,15 @@ func startBroadcastWithAlterAliasLock(ctx context.Context, dbName string, collec
}
return broadcaster, nil
}
// startBroadcastWithCollectionLock starts a broadcast while holding the collection-level lock.
func startBroadcastWithCollectionLock(ctx context.Context, dbName string, collectionName string) (broadcaster.BroadcastAPI, error) {
broadcaster, err := broadcast.StartBroadcastWithResourceKeys(ctx,
message.NewSharedDBNameResourceKey(dbName),
message.NewExclusiveCollectionNameResourceKey(dbName, collectionName),
)
if err != nil {
return nil, errors.Wrap(err, "failed to start broadcast with collection lock")
}
return broadcaster, nil
}
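A minimal usage sketch of this helper, mirroring how broadcastCreateCollectionV1 in this patch consumes it (construction of msg is elided):

    broadcaster, err := startBroadcastWithCollectionLock(ctx, dbName, collectionName)
    if err != nil {
        return err
    }
    defer broadcaster.Close()
    // build the broadcast message carrying the same resource keys, then:
    _, err = broadcaster.Broadcast(ctx, msg)
    return err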

View File

@ -21,62 +21,48 @@ import (
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/metastore/kv/rootcoord"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
pb "github.com/milvus-io/milvus/pkg/v2/proto/etcdpb"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func TestDDLCallbacksAliasDDL(t *testing.T) {
initStreamingSystem()
kv, _ := kvfactory.GetEtcdAndPath()
path := funcutil.RandomString(10)
catalogKV := etcdkv.NewEtcdKV(kv, path)
ss, err := rootcoord.NewSuffixSnapshot(catalogKV, rootcoord.SnapshotsSep, path, rootcoord.SnapshotPrefix)
require.NoError(t, err)
testDB := newNameDb()
collID2Meta := make(map[typeutil.UniqueID]*model.Collection)
core := newTestCore(withHealthyCode(),
withMeta(&MetaTable{
catalog: rootcoord.NewCatalog(catalogKV, ss),
names: testDB,
aliases: newNameDb(),
dbName2Meta: make(map[string]*model.Database),
collID2Meta: collID2Meta,
}),
withValidProxyManager(),
withValidIDAllocator(),
)
registry.ResetRegistration()
RegisterDDLCallbacks(core)
core := initStreamingSystemAndCore(t)
// create database and collection to test alias ddl.
status, err := core.CreateDatabase(context.Background(), &milvuspb.CreateDatabaseRequest{
DbName: "test",
})
require.NoError(t, merr.CheckRPCCall(status, err))
// TODO: after refactor create collection, we can use CreateCollection to create a collection directly.
testDB.insert("test", "test_collection", 1)
testDB.insert("test", "test_collection2", 2)
collID2Meta[1] = &model.Collection{
CollectionID: 1,
Name: "test_collection",
State: pb.CollectionState_CollectionCreated,
}
collID2Meta[2] = &model.Collection{
CollectionID: 2,
Name: "test_collection2",
State: pb.CollectionState_CollectionCreated,
testSchema := &schemapb.CollectionSchema{
Name: "test_collection",
Description: "",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
Name: "field1",
DataType: schemapb.DataType_Int64,
},
},
}
schemaBytes, _ := proto.Marshal(testSchema)
status, err = core.CreateCollection(context.Background(), &milvuspb.CreateCollectionRequest{
DbName: "test",
CollectionName: "test_collection",
Schema: schemaBytes,
})
require.NoError(t, merr.CheckRPCCall(status, err))
testSchema.Name = "test_collection2"
schemaBytes, _ = proto.Marshal(testSchema)
status, err = core.CreateCollection(context.Background(), &milvuspb.CreateCollectionRequest{
DbName: "test",
CollectionName: "test_collection2",
Schema: schemaBytes,
})
require.NoError(t, merr.CheckRPCCall(status, err))
// create an alias with a not-exist database.
status, err = core.CreateAlias(context.Background(), &milvuspb.CreateAliasRequest{
@ -104,7 +90,7 @@ func TestDDLCallbacksAliasDDL(t *testing.T) {
coll, err := core.meta.GetCollectionByName(context.Background(), "test", "test_alias", typeutil.MaxTimestamp)
require.NoError(t, err)
require.Equal(t, int64(1), coll.CollectionID)
require.NotZero(t, coll.CollectionID)
require.Equal(t, "test_collection", coll.Name)
// create an alias already created on current collection should be ok.

View File

@ -0,0 +1,152 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"testing"
"github.com/samber/lo"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func TestDDLCallbacksCollectionDDL(t *testing.T) {
core := initStreamingSystemAndCore(t)
ctx := context.Background()
dbName := "testDB" + funcutil.RandomString(10)
collectionName := "testCollection" + funcutil.RandomString(10)
partitionName := "testPartition" + funcutil.RandomString(10)
testSchema := &schemapb.CollectionSchema{
Name: collectionName,
Description: "",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
Name: "field1",
DataType: schemapb.DataType_Int64,
},
},
}
schemaBytes, err := proto.Marshal(testSchema)
require.NoError(t, err)
// dropping a collection whose database does not exist should be ignored.
status, err := core.DropCollection(ctx, &milvuspb.DropCollectionRequest{
DbName: "notExistDB",
CollectionName: collectionName,
})
require.NoError(t, merr.CheckRPCCall(status, err))
// dropping a collection that does not exist should be ignored.
status, err = core.DropCollection(ctx, &milvuspb.DropCollectionRequest{
DbName: dbName,
CollectionName: "notExistCollection",
})
require.NoError(t, merr.CheckRPCCall(status, err))
// creating a collection in a database that does not exist should return an error.
status, err = core.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
DbName: "notExistDB",
CollectionName: collectionName,
Schema: schemaBytes,
})
require.Error(t, merr.CheckRPCCall(status, err))
// Test CreateCollection
// create a database and a collection.
status, err = core.CreateDatabase(ctx, &milvuspb.CreateDatabaseRequest{
DbName: dbName,
})
require.NoError(t, merr.CheckRPCCall(status, err))
status, err = core.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Schema: schemaBytes,
})
require.NoError(t, merr.CheckRPCCall(status, err))
coll, err := core.meta.GetCollectionByName(ctx, dbName, collectionName, typeutil.MaxTimestamp)
require.NoError(t, err)
require.Equal(t, coll.Name, collectionName)
// creating a collection with the same schema should be idempotent.
status, err = core.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Schema: schemaBytes,
})
require.NoError(t, merr.CheckRPCCall(status, err))
// Test CreatePartition
status, err = core.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{
DbName: dbName,
CollectionName: collectionName,
PartitionName: partitionName,
})
require.NoError(t, merr.CheckRPCCall(status, err))
coll, err = core.meta.GetCollectionByName(ctx, dbName, collectionName, typeutil.MaxTimestamp)
require.NoError(t, err)
require.Len(t, coll.Partitions, 2)
require.Contains(t, lo.Map(coll.Partitions, func(p *model.Partition, _ int) string { return p.PartitionName }), partitionName)
// creating a partition with the same name should be idempotent.
status, err = core.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{
DbName: dbName,
CollectionName: collectionName,
PartitionName: partitionName,
})
require.NoError(t, merr.CheckRPCCall(status, err))
coll, err = core.meta.GetCollectionByName(ctx, dbName, collectionName, typeutil.MaxTimestamp)
require.NoError(t, err)
require.Len(t, coll.Partitions, 2)
status, err = core.DropPartition(ctx, &milvuspb.DropPartitionRequest{
DbName: dbName,
CollectionName: collectionName,
PartitionName: partitionName,
})
require.NoError(t, merr.CheckRPCCall(status, err))
// dropping a partition that does not exist should be idempotent.
status, err = core.DropPartition(ctx, &milvuspb.DropPartitionRequest{
DbName: dbName,
CollectionName: collectionName,
PartitionName: partitionName,
})
require.NoError(t, merr.CheckRPCCall(status, err))
// Test DropCollection
// dropping the collection should succeed.
status, err = core.DropCollection(ctx, &milvuspb.DropCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
})
require.NoError(t, merr.CheckRPCCall(status, err))
_, err = core.meta.GetCollectionByName(ctx, dbName, collectionName, typeutil.MaxTimestamp)
require.Error(t, err)
// dropping an already-dropped collection should be idempotent.
status, err = core.DropCollection(ctx, &milvuspb.DropCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
})
require.NoError(t, merr.CheckRPCCall(status, err))
}

View File

@ -0,0 +1,220 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"fmt"
"github.com/cockroachdb/errors"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/etcdpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message/adaptor"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message/ce"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func (c *Core) broadcastCreateCollectionV1(ctx context.Context, req *milvuspb.CreateCollectionRequest) error {
schema := &schemapb.CollectionSchema{}
if err := proto.Unmarshal(req.GetSchema(), schema); err != nil {
return err
}
if req.GetShardsNum() <= 0 {
req.ShardsNum = common.DefaultShardsNum
}
if _, err := typeutil.GetPartitionKeyFieldSchema(schema); err == nil {
if req.GetNumPartitions() <= 0 {
req.NumPartitions = common.DefaultPartitionsWithPartitionKey
}
} else {
// only one partition is supported when the partition key is not enabled.
req.NumPartitions = int64(1)
}
broadcaster, err := startBroadcastWithCollectionLock(ctx, req.GetDbName(), req.GetCollectionName())
if err != nil {
return err
}
defer broadcaster.Close()
// prepare and validate the create collection message.
createCollectionTask := createCollectionTask{
Core: c,
Req: req,
header: &message.CreateCollectionMessageHeader{},
body: &message.CreateCollectionRequest{
DbName: req.GetDbName(),
CollectionName: req.GetCollectionName(),
CollectionSchema: schema,
},
}
if err := createCollectionTask.Prepare(ctx); err != nil {
return err
}
// set up the broadcast virtual channels and the control channel, then build a broadcast message.
broadcastChannel := make([]string, 0, createCollectionTask.Req.ShardsNum+1)
broadcastChannel = append(broadcastChannel, streaming.WAL().ControlChannel())
for i := 0; i < int(createCollectionTask.Req.ShardsNum); i++ {
broadcastChannel = append(broadcastChannel, createCollectionTask.body.VirtualChannelNames[i])
}
msg := message.NewCreateCollectionMessageBuilderV1().
WithHeader(createCollectionTask.header).
WithBody(createCollectionTask.body).
WithBroadcast(broadcastChannel,
message.NewSharedDBNameResourceKey(createCollectionTask.body.DbName),
message.NewExclusiveCollectionNameResourceKey(createCollectionTask.body.DbName, createCollectionTask.body.CollectionName),
).
MustBuildBroadcast()
if _, err := broadcaster.Broadcast(ctx, msg); err != nil {
return err
}
return nil
}
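For illustration, with ShardsNum = 2 the broadcast above targets three channels, the control channel plus one vchannel per shard (channel names here are hypothetical):

    broadcastChannel := []string{
        streaming.WAL().ControlChannel(), // control channel
        "by-dev-rootcoord-dml_0_100v0",   // shard 0 vchannel
        "by-dev-rootcoord-dml_1_100v1",   // shard 1 vchannel
    }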
func (c *DDLCallback) createCollectionV1AckCallback(ctx context.Context, result message.BroadcastResultCreateCollectionMessageV1) error {
msg := result.Message
header := msg.Header()
body := msg.MustBody()
for vchannel, result := range result.Results {
if !funcutil.IsControlChannel(vchannel) {
// create shard info when virtual channel is created.
if err := c.createCollectionShard(ctx, header, body, vchannel, result); err != nil {
return errors.Wrap(err, "failed to create collection shard")
}
}
}
newCollInfo := newCollectionModelWithMessage(header, body, result)
if err := c.meta.AddCollection(ctx, newCollInfo); err != nil {
return errors.Wrap(err, "failed to add collection to meta table")
}
return c.ExpireCaches(ctx, ce.NewBuilder().WithLegacyProxyCollectionMetaCache(
ce.OptLPCMDBName(body.DbName),
ce.OptLPCMCollectionName(body.CollectionName),
ce.OptLPCMCollectionID(header.CollectionId),
ce.OptLPCMMsgType(commonpb.MsgType_DropCollection)),
newCollInfo.UpdateTimestamp,
)
}
func (c *DDLCallback) createCollectionShard(ctx context.Context, header *message.CreateCollectionMessageHeader, body *message.CreateCollectionRequest, vchannel string, appendResult *message.AppendResult) error {
// TODO: the channel watch here is redundant for now; remove it in the future.
startPosition := adaptor.MustGetMQWrapperIDFromMessage(appendResult.MessageID).Serialize()
// semantically, we should use the last confirmed message id to set up the start position,
// the same as in `newCollectionModelWithMessage` below.
resp, err := c.mixCoord.WatchChannels(ctx, &datapb.WatchChannelsRequest{
CollectionID: header.CollectionId,
ChannelNames: []string{vchannel},
StartPositions: []*commonpb.KeyDataPair{{Key: funcutil.ToPhysicalChannel(vchannel), Data: startPosition}},
Schema: body.CollectionSchema,
CreateTimestamp: appendResult.TimeTick,
})
return merr.CheckRPCCall(resp.GetStatus(), err)
}
// newCollectionModelWithMessage creates a collection model with the given message.
func newCollectionModelWithMessage(header *message.CreateCollectionMessageHeader, body *message.CreateCollectionRequest, result message.BroadcastResultCreateCollectionMessageV1) *model.Collection {
timetick := result.GetControlChannelResult().TimeTick
// Set up the start positions for the vchannels
newCollInfo := newCollectionModel(header, body, timetick)
startPosition := make(map[string][]byte, len(body.PhysicalChannelNames))
for vchannel, appendResult := range result.Results {
if funcutil.IsControlChannel(vchannel) {
// use the control channel time tick to set up the create time and update timestamp
newCollInfo.CreateTime = appendResult.TimeTick
newCollInfo.UpdateTimestamp = appendResult.TimeTick
for _, partition := range newCollInfo.Partitions {
partition.PartitionCreatedTimestamp = appendResult.TimeTick
}
continue
}
startPosition[funcutil.ToPhysicalChannel(vchannel)] = adaptor.MustGetMQWrapperIDFromMessage(appendResult.MessageID).Serialize()
// semantically, we should use the last confirmed message id to set up the start position, like the following:
// startPosition := adaptor.MustGetMQWrapperIDFromMessage(appendResult.LastConfirmedMessageID).Serialize()
// but currently the zero message id is serialized to nil when using woodpecker,
// and some code assertions will panic if the start position is nil.
// so we use the message id here: since the vchannel is created by this CreateCollectionMessage,
// consuming from its message id covers every message in the vchannel, just like LastConfirmedMessageID would.
}
newCollInfo.StartPositions = toKeyDataPairs(startPosition)
return newCollInfo
}
// newCollectionModel creates a collection model with the given header, body and timestamp.
func newCollectionModel(header *message.CreateCollectionMessageHeader, body *message.CreateCollectionRequest, ts uint64) *model.Collection {
partitions := make([]*model.Partition, 0, len(body.PartitionIDs))
for idx, partition := range body.PartitionIDs {
partitions = append(partitions, &model.Partition{
PartitionID: partition,
PartitionName: body.PartitionNames[idx],
PartitionCreatedTimestamp: ts,
CollectionID: header.CollectionId,
State: etcdpb.PartitionState_PartitionCreated,
})
}
consistencyLevel, properties := mustConsumeConsistencyLevel(body.CollectionSchema.Properties)
return &model.Collection{
CollectionID: header.CollectionId,
DBID: header.DbId,
Name: body.CollectionSchema.Name,
DBName: body.DbName,
Description: body.CollectionSchema.Description,
AutoID: body.CollectionSchema.AutoID,
Fields: model.UnmarshalFieldModels(body.CollectionSchema.Fields),
StructArrayFields: model.UnmarshalStructArrayFieldModels(body.CollectionSchema.StructArrayFields),
Functions: model.UnmarshalFunctionModels(body.CollectionSchema.Functions),
VirtualChannelNames: body.VirtualChannelNames,
PhysicalChannelNames: body.PhysicalChannelNames,
ShardsNum: int32(len(body.VirtualChannelNames)),
ConsistencyLevel: consistencyLevel,
CreateTime: ts,
State: etcdpb.CollectionState_CollectionCreated,
Partitions: partitions,
Properties: properties,
EnableDynamicField: body.CollectionSchema.EnableDynamicField,
UpdateTimestamp: ts,
}
}
// mustConsumeConsistencyLevel consumes the consistency level from the properties and returns the new properties.
// it panics if the consistency level is not found in the properties, because the consistency level is required.
func mustConsumeConsistencyLevel(properties []*commonpb.KeyValuePair) (commonpb.ConsistencyLevel, []*commonpb.KeyValuePair) {
ok, consistencyLevel := getConsistencyLevel(properties...)
if !ok {
panic(fmt.Errorf("consistency level not found in properties"))
}
newProperties := make([]*commonpb.KeyValuePair, 0, len(properties)-1)
for _, property := range properties {
if property.Key == common.ConsistencyLevel {
continue
}
newProperties = append(newProperties, property)
}
return consistencyLevel, newProperties
}

View File

@ -0,0 +1,114 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"fmt"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/metastore/model"
pb "github.com/milvus-io/milvus/pkg/v2/proto/etcdpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message/ce"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
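// broadcastCreatePartition checks capacity and idempotency, allocates a partition ID, and
// broadcasts a CreatePartition message to the control channel and all shard virtual channels.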
func (c *Core) broadcastCreatePartition(ctx context.Context, in *milvuspb.CreatePartitionRequest) error {
broadcaster, err := startBroadcastWithCollectionLock(ctx, in.GetDbName(), in.GetCollectionName())
if err != nil {
return err
}
defer broadcaster.Close()
collMeta, err := c.meta.GetCollectionByName(ctx, in.GetDbName(), in.GetCollectionName(), typeutil.MaxTimestamp)
if err != nil {
return err
}
if err := checkGeneralCapacity(ctx, 0, 1, 0, c); err != nil {
return err
}
// idempotency check here.
for _, partition := range collMeta.Partitions {
if partition.PartitionName == in.GetPartitionName() {
return errIgnoerdCreatePartition
}
}
cfgMaxPartitionNum := Params.RootCoordCfg.MaxPartitionNum.GetAsInt()
if len(collMeta.Partitions) >= cfgMaxPartitionNum {
return fmt.Errorf("partition number (%d) exceeds max configuration (%d), collection: %s",
len(collMeta.Partitions), cfgMaxPartitionNum, collMeta.Name)
}
partID, err := c.idAllocator.AllocOne()
if err != nil {
return errors.Wrap(err, "failed to allocate partition ID")
}
channels := make([]string, 0, collMeta.ShardsNum+1)
channels = append(channels, streaming.WAL().ControlChannel())
for i := 0; i < int(collMeta.ShardsNum); i++ {
channels = append(channels, collMeta.VirtualChannelNames[i])
}
msg := message.NewCreatePartitionMessageBuilderV1().
WithHeader(&message.CreatePartitionMessageHeader{
CollectionId: collMeta.CollectionID,
PartitionId: partID,
}).
WithBody(&message.CreatePartitionRequest{
Base: commonpbutil.NewMsgBase(commonpbutil.WithMsgType(commonpb.MsgType_CreatePartition)),
DbName: in.GetDbName(),
CollectionName: in.GetCollectionName(),
PartitionName: in.GetPartitionName(),
DbID: collMeta.DBID,
CollectionID: collMeta.CollectionID,
PartitionID: partID,
}).
WithBroadcast(channels).
MustBuildBroadcast()
_, err = broadcaster.Broadcast(ctx, msg)
return err
}
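// createPartitionV1AckCallback persists the new partition meta once the CreatePartition
// broadcast is acknowledged, then expires the proxy meta cache.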
func (c *DDLCallback) createPartitionV1AckCallback(ctx context.Context, result message.BroadcastResultCreatePartitionMessageV1) error {
header := result.Message.Header()
body := result.Message.MustBody()
partition := &model.Partition{
PartitionID: header.PartitionId,
PartitionName: result.Message.MustBody().PartitionName,
PartitionCreatedTimestamp: result.GetControlChannelResult().TimeTick,
CollectionID: header.CollectionId,
State: pb.PartitionState_PartitionCreated,
}
if err := c.meta.AddPartition(ctx, partition); err != nil {
return errors.Wrap(err, "failed to add partition meta")
}
return c.ExpireCaches(ctx, ce.NewBuilder().
WithLegacyProxyCollectionMetaCache(
ce.OptLPCMDBName(body.DbName),
ce.OptLPCMCollectionName(body.CollectionName),
ce.OptLPCMCollectionID(header.CollectionId),
ce.OptLPCMPartitionName(body.PartitionName),
ce.OptLPCMMsgType(commonpb.MsgType_CreatePartition),
),
result.GetControlChannelResult().TimeTick)
}

View File

@ -24,38 +24,13 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/metastore/kv/rootcoord"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func TestDDLCallbacksDatabaseDDL(t *testing.T) {
initStreamingSystem()
kv, _ := kvfactory.GetEtcdAndPath()
path := funcutil.RandomString(10)
catalogKV := etcdkv.NewEtcdKV(kv, path)
ss, err := rootcoord.NewSuffixSnapshot(catalogKV, rootcoord.SnapshotsSep, path, rootcoord.SnapshotPrefix)
require.NoError(t, err)
core := newTestCore(withHealthyCode(),
withMeta(&MetaTable{
catalog: rootcoord.NewCatalog(catalogKV, ss),
names: newNameDb(),
aliases: newNameDb(),
dbName2Meta: make(map[string]*model.Database),
}),
withValidProxyManager(),
withValidIDAllocator(),
)
registry.ResetRegistration()
RegisterDDLCallbacks(core)
core := initStreamingSystemAndCore(t)
// Create a new database
status, err := core.CreateDatabase(context.Background(), &milvuspb.CreateDatabaseRequest{

View File

@ -0,0 +1,163 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"fmt"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message/ce"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
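// broadcastDropCollectionV1 prepares the drop-collection task and broadcasts a DropCollection
// message to the control channel and the collection's virtual channels.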
func (c *Core) broadcastDropCollectionV1(ctx context.Context, req *milvuspb.DropCollectionRequest) error {
broadcaster, err := startBroadcastWithCollectionLock(ctx, req.GetDbName(), req.GetCollectionName())
if err != nil {
return err
}
defer broadcaster.Close()
dropCollectionTask := &dropCollectionTask{
Core: c,
Req: req,
}
if err := dropCollectionTask.Prepare(ctx); err != nil {
return err
}
channels := make([]string, 0, len(dropCollectionTask.vchannels)+1)
channels = append(channels, streaming.WAL().ControlChannel())
channels = append(channels, dropCollectionTask.vchannels...)
msg := message.NewDropCollectionMessageBuilderV1().
WithHeader(dropCollectionTask.header).
WithBody(dropCollectionTask.body).
WithBroadcast(channels).
MustBuildBroadcast()
if _, err := broadcaster.Broadcast(ctx, msg); err != nil {
return err
}
return nil
}
// dropCollectionV1AckCallback is called when the drop collection message is acknowledged
func (c *DDLCallback) dropCollectionV1AckCallback(ctx context.Context, result message.BroadcastResultDropCollectionMessageV1) error {
msg := result.Message
header := msg.Header()
body := msg.MustBody()
for vchannel, result := range result.Results {
collectionID := msg.Header().CollectionId
if funcutil.IsControlChannel(vchannel) {
// when the control channel is acknowledged, we should do the following steps:
// 1. release the collection from querycoord first.
dropLoadConfigMsg := message.NewDropLoadConfigMessageBuilderV2().
WithHeader(&message.DropLoadConfigMessageHeader{
DbId: msg.Header().DbId,
CollectionId: collectionID,
}).
WithBody(&message.DropLoadConfigMessageBody{}).
WithBroadcast([]string{streaming.WAL().ControlChannel()}).
MustBuildBroadcast().
WithBroadcastID(msg.BroadcastHeader().BroadcastID)
if err := registry.CallMessageAckCallback(ctx, dropLoadConfigMsg, map[string]*message.AppendResult{
streaming.WAL().ControlChannel(): result,
}); err != nil {
return errors.Wrap(err, "failed to release collection")
}
// 2. drop the collection index.
dropIndexMsg := message.NewDropIndexMessageBuilderV2().
WithHeader(&message.DropIndexMessageHeader{
CollectionId: collectionID,
}).
WithBody(&message.DropIndexMessageBody{}).
WithBroadcast([]string{streaming.WAL().ControlChannel()}).
MustBuildBroadcast().
WithBroadcastID(msg.BroadcastHeader().BroadcastID)
if err := registry.CallMessageAckCallback(ctx, dropIndexMsg, map[string]*message.AppendResult{
streaming.WAL().ControlChannel(): result,
}); err != nil {
return errors.Wrap(err, "failed to drop collection index")
}
// 3. drop the collection meta itself.
if err := c.meta.DropCollection(ctx, collectionID, result.TimeTick); err != nil {
return errors.Wrap(err, "failed to drop collection")
}
continue
}
// Drop virtual channel data when the vchannel is acknowledged.
resp, err := c.mixCoord.DropVirtualChannel(ctx, &datapb.DropVirtualChannelRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
ChannelName: vchannel,
})
if err := merr.CheckRPCCall(resp, err); err != nil {
return errors.Wrap(err, "failed to drop virtual channel")
}
}
// add the collection tombstone to the sweeper.
c.tombstoneSweeper.AddTombstone(newCollectionTombstone(c.meta, c.broker, header.CollectionId))
// expire the collection meta cache on proxy.
return c.ExpireCaches(ctx, ce.NewBuilder().WithLegacyProxyCollectionMetaCache(
ce.OptLPCMDBName(body.DbName),
ce.OptLPCMCollectionName(body.CollectionName),
ce.OptLPCMCollectionID(header.CollectionId),
ce.OptLPCMMsgType(commonpb.MsgType_DropCollection)).Build(),
result.GetControlChannelResult().TimeTick)
}
// newCollectionTombstone creates a new collection tombstone.
func newCollectionTombstone(meta IMetaTable, broker Broker, collectionID int64) *collectionTombstone {
return &collectionTombstone{
meta: meta,
broker: broker,
collectionID: collectionID,
}
}
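// collectionTombstone defers the physical removal of a dropped collection's meta until the
// broker confirms that its data has been garbage collected.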
type collectionTombstone struct {
meta IMetaTable
broker Broker
collectionID int64
}
func (t *collectionTombstone) ID() string {
return fmt.Sprintf("c:%d", t.collectionID)
}
func (t *collectionTombstone) ConfirmCanBeRemoved(ctx context.Context) (bool, error) {
return t.broker.GcConfirm(ctx, t.collectionID, common.AllPartitionsID), nil
}
func (t *collectionTombstone) Remove(ctx context.Context) error {
return t.meta.RemoveCollection(ctx, t.collectionID, 0)
}

View File

@ -0,0 +1,148 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"fmt"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message/ce"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
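// broadcastDropPartition resolves the partition ID and broadcasts a DropPartition message to
// the control channel and all shard virtual channels; dropping the default partition is rejected.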
func (c *Core) broadcastDropPartition(ctx context.Context, in *milvuspb.DropPartitionRequest) error {
if in.GetPartitionName() == Params.CommonCfg.DefaultPartitionName.GetValue() {
return errors.New("default partition cannot be deleted")
}
broadcaster, err := startBroadcastWithCollectionLock(ctx, in.GetDbName(), in.GetCollectionName())
if err != nil {
return err
}
defer broadcaster.Close()
collMeta, err := c.meta.GetCollectionByName(ctx, in.GetDbName(), in.GetCollectionName(), typeutil.MaxTimestamp)
if err != nil {
// Is this idempotent?
return err
}
partID := common.InvalidPartitionID
for _, partition := range collMeta.Partitions {
if partition.PartitionName == in.GetPartitionName() {
partID = partition.PartitionID
break
}
}
if partID == common.InvalidPartitionID {
return errIgnoredDropPartition
}
channels := make([]string, 0, collMeta.ShardsNum+1)
channels = append(channels, streaming.WAL().ControlChannel())
for i := 0; i < int(collMeta.ShardsNum); i++ {
channels = append(channels, collMeta.VirtualChannelNames[i])
}
msg := message.NewDropPartitionMessageBuilderV1().
WithHeader(&message.DropPartitionMessageHeader{
CollectionId: collMeta.CollectionID,
PartitionId: partID,
}).
WithBody(&message.DropPartitionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_DropPartition,
},
DbName: in.GetDbName(),
CollectionName: in.GetCollectionName(),
PartitionName: in.GetPartitionName(),
DbID: collMeta.DBID,
CollectionID: collMeta.CollectionID,
PartitionID: partID,
}).
WithBroadcast(channels).
MustBuildBroadcast()
_, err = broadcaster.Broadcast(ctx, msg)
return err
}
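// dropPartitionV1AckCallback drops the partition's historical data and meta once the broadcast
// is acknowledged, registers a partition tombstone for GC, and expires the proxy meta cache.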
func (c *DDLCallback) dropPartitionV1AckCallback(ctx context.Context, result message.BroadcastResultDropPartitionMessageV1) error {
header := result.Message.Header()
body := result.Message.MustBody()
for vchannel := range result.Results {
if funcutil.IsControlChannel(vchannel) {
continue
}
// drop all historical partition data when the vchannel is acknowledged.
if err := c.mixCoord.NotifyDropPartition(ctx, vchannel, []int64{header.PartitionId}); err != nil {
return err
}
}
if err := c.meta.DropPartition(ctx, header.CollectionId, header.PartitionId, result.GetControlChannelResult().TimeTick); err != nil {
return err
}
// add the partition tombstone to the sweeper.
c.tombstoneSweeper.AddTombstone(newPartitionTombstone(c.meta, c.broker, header.CollectionId, header.PartitionId))
// expire the partition meta cache on proxy.
return c.ExpireCaches(ctx, ce.NewBuilder().
WithLegacyProxyCollectionMetaCache(
ce.OptLPCMDBName(body.DbName),
ce.OptLPCMCollectionName(body.CollectionName),
ce.OptLPCMCollectionID(header.CollectionId),
ce.OptLPCMPartitionName(body.PartitionName),
ce.OptLPCMMsgType(commonpb.MsgType_DropPartition),
),
result.GetControlChannelResult().TimeTick)
}
// newPartitionTombstone creates a new partition tombstone.
func newPartitionTombstone(meta IMetaTable, broker Broker, collectionID int64, partitionID int64) *partitionTombstone {
return &partitionTombstone{
meta: meta,
broker: broker,
collectionID: collectionID,
partitionID: partitionID,
}
}
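// partitionTombstone defers the removal of a dropped partition until the broker confirms that
// its data has been garbage collected.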
type partitionTombstone struct {
meta IMetaTable
broker Broker
collectionID int64
partitionID int64
}
func (t *partitionTombstone) ID() string {
return fmt.Sprintf("p:%d:%d", t.collectionID, t.partitionID)
}
func (t *partitionTombstone) ConfirmCanBeRemoved(ctx context.Context) (bool, error) {
return t.broker.GcConfirm(ctx, t.collectionID, t.partitionID), nil
}
func (t *partitionTombstone) Remove(ctx context.Context) error {
return t.meta.RemoveCollection(ctx, t.collectionID, 0)
}

View File

@ -24,10 +24,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/metastore/kv/rootcoord"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
@ -35,21 +31,9 @@ import (
)
func TestDDLCallbacksRBACCredential(t *testing.T) {
initStreamingSystem()
kv, _ := kvfactory.GetEtcdAndPath()
path := funcutil.RandomString(10)
catalogKV := etcdkv.NewEtcdKV(kv, path)
core := initStreamingSystemAndCore(t)
testUserName := "user" + funcutil.RandomString(10)
core := newTestCore(withHealthyCode(),
withMeta(&MetaTable{catalog: rootcoord.NewCatalog(catalogKV, nil)}),
withValidProxyManager(),
)
registry.ResetRegistration()
RegisterDDLCallbacks(core)
// Delete a not existed credential should succeed
status, err := core.DeleteCredential(context.Background(), &milvuspb.DeleteCredentialRequest{
Username: testUserName,

View File

@ -23,28 +23,12 @@ import (
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/metastore/kv/rootcoord"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
func TestDDLCallbacksRBACPrivilege(t *testing.T) {
initStreamingSystem()
kv, _ := kvfactory.GetEtcdAndPath()
path := funcutil.RandomString(10)
catalogKV := etcdkv.NewEtcdKV(kv, path)
core := newTestCore(withHealthyCode(),
withMeta(&MetaTable{catalog: rootcoord.NewCatalog(catalogKV, nil)}),
withValidProxyManager(),
)
registry.ResetRegistration()
RegisterDDLCallbacks(core)
core := initStreamingSystemAndCore(t)
// Create a new role.
targetRoleName := "newRole"
@ -164,18 +148,7 @@ func TestDDLCallbacksRBACPrivilege(t *testing.T) {
}
func TestDDLCallbacksRBACPrivilegeGroup(t *testing.T) {
initStreamingSystem()
kv, _ := kvfactory.GetEtcdAndPath()
path := funcutil.RandomString(10)
catalogKV := etcdkv.NewEtcdKV(kv, path)
core := newTestCore(withHealthyCode(),
withMeta(&MetaTable{catalog: rootcoord.NewCatalog(catalogKV, nil)}),
withValidProxyManager(),
)
registry.ResetRegistration()
RegisterDDLCallbacks(core)
core := initStreamingSystemAndCore(t)
groupName := "group1"
status, err := core.CreatePrivilegeGroup(context.Background(), &milvuspb.CreatePrivilegeGroupRequest{

View File

@ -24,28 +24,12 @@ import (
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/metastore/kv/rootcoord"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
func TestDDLCallbacksRBACRestore(t *testing.T) {
initStreamingSystem()
kv, _ := kvfactory.GetEtcdAndPath()
path := funcutil.RandomString(10)
catalogKV := etcdkv.NewEtcdKV(kv, path)
core := newTestCore(withHealthyCode(),
withMeta(&MetaTable{catalog: rootcoord.NewCatalog(catalogKV, nil)}),
withValidProxyManager(),
)
registry.ResetRegistration()
RegisterDDLCallbacks(core)
core := initStreamingSystemAndCore(t)
ctx := context.Background()
rbacMeta := &milvuspb.RBACMeta{

View File

@ -24,10 +24,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/metastore/kv/rootcoord"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
@ -36,18 +32,7 @@ import (
)
func TestDDLCallbacksRBACRole(t *testing.T) {
initStreamingSystem()
kv, _ := kvfactory.GetEtcdAndPath()
path := funcutil.RandomString(10)
catalogKV := etcdkv.NewEtcdKV(kv, path)
core := newTestCore(withHealthyCode(),
withMeta(&MetaTable{catalog: rootcoord.NewCatalog(catalogKV, nil)}),
withValidProxyManager(),
)
registry.ResetRegistration()
RegisterDDLCallbacks(core)
core := initStreamingSystemAndCore(t)
// Test drop builtin role should return error
roleDbAdmin := "db_admin"

View File

@ -25,50 +25,39 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/util/proxyutil"
"github.com/milvus-io/milvus/pkg/v2/log"
pb "github.com/milvus-io/milvus/pkg/v2/proto/etcdpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
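// dropCollectionTask validates a DropCollection request and fills in the message header, body,
// and target virtual channels for the subsequent broadcast.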
type dropCollectionTask struct {
baseTask
Req *milvuspb.DropCollectionRequest
*Core
Req *milvuspb.DropCollectionRequest
header *message.DropCollectionMessageHeader
body *message.DropCollectionRequest
vchannels []string
}
func (t *dropCollectionTask) validate(ctx context.Context) error {
if err := CheckMsgType(t.Req.GetBase().GetMsgType(), commonpb.MsgType_DropCollection); err != nil {
return err
}
if t.core.meta.IsAlias(ctx, t.Req.GetDbName(), t.Req.GetCollectionName()) {
if t.meta.IsAlias(ctx, t.Req.GetDbName(), t.Req.GetCollectionName()) {
return fmt.Errorf("cannot drop the collection via alias = %s", t.Req.CollectionName)
}
return nil
}
func (t *dropCollectionTask) Prepare(ctx context.Context) error {
return t.validate(ctx)
}
func (t *dropCollectionTask) Execute(ctx context.Context) error {
// use max ts to check if the latest collection exists.
// we cannot handle the case where a collection is dropped with `ts1` while the catalog
// contains a collection with a newer ts greater than `ts1`.
// fortunately, if DDLs are guaranteed to execute in sequence, everything is OK: `ts1` will always be the latest.
collMeta, err := t.core.meta.GetCollectionByName(ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), typeutil.MaxTimestamp)
if errors.Is(err, merr.ErrCollectionNotFound) || errors.Is(err, merr.ErrDatabaseNotFound) {
// make dropping collection idempotent.
log.Ctx(ctx).Warn("drop non-existent collection", zap.String("collection", t.Req.GetCollectionName()), zap.String("database", t.Req.GetDbName()))
return nil
}
collMeta, err := t.meta.GetCollectionByName(ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), typeutil.MaxTimestamp)
if err != nil {
if errors.Is(err, merr.ErrCollectionNotFound) || errors.Is(err, merr.ErrDatabaseNotFound) {
return errIgnoredDropCollection
}
return err
}
// meta cache of all aliases should also be cleaned.
aliases := t.core.meta.ListAliasesByID(ctx, collMeta.CollectionID)
aliases := t.meta.ListAliasesByID(ctx, collMeta.CollectionID)
// Check if all aliases have been dropped.
if len(aliases) > 0 {
@ -77,79 +66,25 @@ func (t *dropCollectionTask) Execute(ctx context.Context) error {
return err
}
ts := t.GetTs()
return executeDropCollectionTaskSteps(ctx,
t.core, collMeta, t.Req.GetDbName(), aliases,
t.Req.GetBase().GetReplicateInfo().GetIsReplicate(),
ts)
// fill the message body and header
// TODO: cleanupMetricsStep
t.header = &message.DropCollectionMessageHeader{
CollectionId: collMeta.CollectionID,
DbId: collMeta.DBID,
}
t.body = &message.DropCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_DropCollection,
},
CollectionID: collMeta.CollectionID,
DbID: collMeta.DBID,
CollectionName: t.Req.CollectionName,
DbName: collMeta.DBName,
}
t.vchannels = collMeta.VirtualChannelNames
return nil
}
func (t *dropCollectionTask) GetLockerKey() LockerKey {
collection := t.core.getCollectionIDStr(t.ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), 0)
return NewLockerKeyChain(
NewClusterLockerKey(false),
NewDatabaseLockerKey(t.Req.GetDbName(), false),
NewCollectionLockerKey(collection, true),
)
}
func executeDropCollectionTaskSteps(ctx context.Context,
core *Core,
col *model.Collection,
dbName string,
alias []string,
isReplicate bool,
ts Timestamp,
) error {
redoTask := newBaseRedoTask(core.stepExecutor)
redoTask.AddSyncStep(&expireCacheStep{
baseStep: baseStep{core: core},
dbName: dbName,
collectionNames: append(alias, col.Name),
collectionID: col.CollectionID,
ts: ts,
opts: []proxyutil.ExpireCacheOpt{proxyutil.SetMsgType(commonpb.MsgType_DropCollection)},
})
redoTask.AddSyncStep(&changeCollectionStateStep{
baseStep: baseStep{core: core},
collectionID: col.CollectionID,
state: pb.CollectionState_CollectionDropping,
ts: ts,
})
redoTask.AddSyncStep(&cleanupMetricsStep{
baseStep: baseStep{core: core},
dbName: dbName,
collectionName: col.Name,
})
redoTask.AddAsyncStep(&releaseCollectionStep{
baseStep: baseStep{core: core},
collectionID: col.CollectionID,
})
redoTask.AddAsyncStep(&dropIndexStep{
baseStep: baseStep{core: core},
collID: col.CollectionID,
partIDs: nil,
})
redoTask.AddAsyncStep(&deleteCollectionDataStep{
baseStep: baseStep{core: core},
coll: col,
isSkip: isReplicate,
})
redoTask.AddAsyncStep(&removeDmlChannelsStep{
baseStep: baseStep{core: core},
pChannels: col.PhysicalChannelNames,
})
redoTask.AddAsyncStep(newConfirmGCStep(core, col.CollectionID, allPartition))
redoTask.AddAsyncStep(&deleteCollectionMetaStep{
baseStep: baseStep{core: core},
collectionID: col.CollectionID,
// This ts is less than the ts when we notify data nodes to drop collection, but it's OK since we have already
// marked this collection as deleted. If we want to make this ts greater than the notification's ts, we should
// wrap a step who will have these three children and connect them with ts.
ts: ts,
})
return redoTask.Execute(ctx)
func (t *dropCollectionTask) Prepare(ctx context.Context) error {
return t.validate(ctx)
}

View File

@ -18,11 +18,8 @@ package rootcoord
import (
"context"
"strings"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
@ -30,21 +27,12 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/metastore/model"
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
pb "github.com/milvus-io/milvus/pkg/v2/proto/etcdpb"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
)
func Test_dropCollectionTask_Prepare(t *testing.T) {
t.Run("invalid msg type", func(t *testing.T) {
task := &dropCollectionTask{
Req: &milvuspb.DropCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DescribeCollection},
},
}
err := task.Prepare(context.Background())
assert.Error(t, err)
})
t.Run("drop via alias", func(t *testing.T) {
collectionName := funcutil.GenRandomStr()
@ -57,7 +45,7 @@ func Test_dropCollectionTask_Prepare(t *testing.T) {
core := newTestCore(withMeta(meta))
task := &dropCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Core: core,
Req: &milvuspb.DropCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropCollection},
CollectionName: collectionName,
@ -67,6 +55,44 @@ func Test_dropCollectionTask_Prepare(t *testing.T) {
assert.Error(t, err)
})
t.Run("collection not found", func(t *testing.T) {
collectionName := funcutil.GenRandomStr()
meta := mockrootcoord.NewIMetaTable(t)
meta.EXPECT().IsAlias(mock.Anything, mock.Anything, mock.Anything).Return(false)
meta.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound)
core := newTestCore(withMeta(meta))
task := &dropCollectionTask{
Core: core,
Req: &milvuspb.DropCollectionRequest{
CollectionName: collectionName,
},
}
err := task.Prepare(context.Background())
assert.Error(t, err)
})
t.Run("collection has aliases", func(t *testing.T) {
collectionName := funcutil.GenRandomStr()
meta := mockrootcoord.NewIMetaTable(t)
meta.EXPECT().IsAlias(mock.Anything, mock.Anything, mock.Anything).Return(false)
meta.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&model.Collection{
CollectionID: 1,
DBID: 1,
State: pb.CollectionState_CollectionCreated,
VirtualChannelNames: []string{"vchannel1"},
}, nil)
meta.EXPECT().ListAliasesByID(mock.Anything, mock.Anything).Return([]string{"alias1"})
core := newTestCore(withMeta(meta))
task := &dropCollectionTask{
Core: core,
Req: &milvuspb.DropCollectionRequest{
CollectionName: collectionName,
},
}
err := task.Prepare(context.Background())
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
collectionName := funcutil.GenRandomStr()
@ -76,10 +102,18 @@ func Test_dropCollectionTask_Prepare(t *testing.T) {
mock.Anything,
mock.Anything,
).Return(false)
meta.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&model.Collection{
CollectionID: 1,
DBName: "db1",
DBID: 1,
State: pb.CollectionState_CollectionCreated,
VirtualChannelNames: []string{"vchannel1"},
}, nil)
meta.EXPECT().ListAliasesByID(mock.Anything, mock.Anything).Return([]string{})
core := newTestCore(withMeta(meta))
task := &dropCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Core: core,
Req: &milvuspb.DropCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropCollection},
CollectionName: collectionName,
@ -87,237 +121,10 @@ func Test_dropCollectionTask_Prepare(t *testing.T) {
}
err := task.Prepare(context.Background())
assert.NoError(t, err)
})
}
func Test_dropCollectionTask_Execute(t *testing.T) {
t.Run("drop non-existent collection", func(t *testing.T) {
collectionName := funcutil.GenRandomStr()
meta := mockrootcoord.NewIMetaTable(t)
meta.On("GetCollectionByName",
mock.Anything, // context.Context.
mock.Anything,
mock.Anything,
mock.Anything,
).Return(nil, func(ctx context.Context, dbName string, name string, ts Timestamp) error {
if collectionName == name {
return merr.WrapErrCollectionNotFound(collectionName)
}
return errors.New("error mock GetCollectionByName")
})
core := newTestCore(withMeta(meta))
task := &dropCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.DropCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropCollection},
CollectionName: collectionName,
},
}
err := task.Execute(context.Background())
assert.NoError(t, err)
task.Req.CollectionName = collectionName + "_test"
err = task.Execute(context.Background())
assert.Error(t, err)
})
t.Run("failed to expire cache", func(t *testing.T) {
collectionName := funcutil.GenRandomStr()
coll := &model.Collection{Name: collectionName}
meta := mockrootcoord.NewIMetaTable(t)
meta.On("GetCollectionByName",
mock.Anything, // context.Context
mock.Anything,
mock.Anything,
mock.Anything,
).Return(coll.Clone(), nil)
meta.On("ListAliasesByID",
mock.Anything,
mock.AnythingOfType("int64"),
).Return([]string{})
core := newTestCore(withInvalidProxyManager(), withMeta(meta))
task := &dropCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.DropCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropCollection},
CollectionName: collectionName,
},
}
err := task.Execute(context.Background())
assert.Error(t, err)
})
t.Run("failed to change collection state", func(t *testing.T) {
collectionName := funcutil.GenRandomStr()
coll := &model.Collection{Name: collectionName}
meta := mockrootcoord.NewIMetaTable(t)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(coll.Clone(), nil)
meta.On("ChangeCollectionState",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(errors.New("error mock ChangeCollectionState"))
meta.On("ListAliasesByID",
mock.Anything,
mock.Anything,
).Return([]string{})
core := newTestCore(withValidProxyManager(), withMeta(meta))
task := &dropCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.DropCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropCollection},
CollectionName: collectionName,
},
}
err := task.Execute(context.Background())
assert.Error(t, err)
})
t.Run("aliases have not been dropped", func(t *testing.T) {
defer cleanTestEnv()
collectionName := funcutil.GenRandomStr()
shardNum := 2
ticker := newRocksMqTtSynchronizer()
pchans := ticker.getDmlChannelNames(shardNum)
ticker.addDmlChannels(pchans...)
coll := &model.Collection{Name: collectionName, ShardsNum: int32(shardNum), PhysicalChannelNames: pchans}
meta := mockrootcoord.NewIMetaTable(t)
meta.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(coll.Clone(), nil)
meta.EXPECT().ListAliasesByID(mock.Anything, mock.Anything).
Return([]string{"mock-alias-0", "mock-alias-1"})
core := newTestCore(
withMeta(meta),
withTtSynchronizer(ticker))
task := &dropCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.DropCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropCollection},
CollectionName: collectionName,
},
}
err := task.Execute(context.Background())
assert.Error(t, err)
assert.True(t, strings.Contains(err.Error(), "please remove all aliases"))
})
t.Run("normal case, redo", func(t *testing.T) {
defer cleanTestEnv()
confirmGCInterval = time.Millisecond
defer restoreConfirmGCInterval()
collectionName := funcutil.GenRandomStr()
shardNum := 2
ticker := newRocksMqTtSynchronizer()
pchans := ticker.getDmlChannelNames(shardNum)
ticker.addDmlChannels(pchans...)
coll := &model.Collection{Name: collectionName, ShardsNum: int32(shardNum), PhysicalChannelNames: pchans}
meta := mockrootcoord.NewIMetaTable(t)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(coll.Clone(), nil)
meta.On("ChangeCollectionState",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(nil)
meta.On("ListAliasesByID",
mock.Anything,
mock.Anything,
).Return([]string{})
removeCollectionMetaCalled := false
removeCollectionMetaChan := make(chan struct{}, 1)
meta.On("RemoveCollection",
mock.Anything,
mock.Anything,
mock.Anything,
).Return(func(ctx context.Context, collID UniqueID, ts Timestamp) error {
removeCollectionMetaCalled = true
removeCollectionMetaChan <- struct{}{}
return nil
})
broker := newMockBroker()
releaseCollectionCalled := false
releaseCollectionChan := make(chan struct{}, 1)
broker.ReleaseCollectionFunc = func(ctx context.Context, collectionID UniqueID) error {
releaseCollectionCalled = true
releaseCollectionChan <- struct{}{}
return nil
}
dropIndexCalled := false
dropIndexChan := make(chan struct{}, 1)
broker.DropCollectionIndexFunc = func(ctx context.Context, collID UniqueID, partIDs []UniqueID) error {
dropIndexCalled = true
dropIndexChan <- struct{}{}
time.Sleep(confirmGCInterval)
return nil
}
broker.GCConfirmFunc = func(ctx context.Context, collectionID, partitionID UniqueID) bool {
return true
}
gc := mockrootcoord.NewGarbageCollector(t)
deleteCollectionCalled := false
deleteCollectionChan := make(chan struct{}, 1)
gc.EXPECT().GcCollectionData(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, coll *model.Collection) (Timestamp, error) {
deleteCollectionCalled = true
deleteCollectionChan <- struct{}{}
return 0, nil
})
core := newTestCore(
withValidProxyManager(),
withMeta(meta),
withBroker(broker),
withGarbageCollector(gc),
withTtSynchronizer(ticker))
task := &dropCollectionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.DropCollectionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropCollection},
CollectionName: collectionName,
},
}
err := task.Execute(context.Background())
assert.NoError(t, err)
// check if redo worked.
<-releaseCollectionChan
assert.True(t, releaseCollectionCalled)
<-dropIndexChan
assert.True(t, dropIndexCalled)
<-deleteCollectionChan
assert.True(t, deleteCollectionCalled)
<-removeCollectionMetaChan
assert.True(t, removeCollectionMetaCalled)
assert.Equal(t, int64(1), task.header.CollectionId)
assert.Equal(t, int64(1), task.header.DbId)
assert.Equal(t, collectionName, task.body.CollectionName)
assert.Equal(t, "db1", task.body.DbName)
assert.Equal(t, []string{"vchannel1"}, task.vchannels)
})
}

View File

@ -1,137 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/util/proxyutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
pb "github.com/milvus-io/milvus/pkg/v2/proto/etcdpb"
)
type dropPartitionTask struct {
baseTask
Req *milvuspb.DropPartitionRequest
collMeta *model.Collection
}
func (t *dropPartitionTask) Prepare(ctx context.Context) error {
if err := CheckMsgType(t.Req.GetBase().GetMsgType(), commonpb.MsgType_DropPartition); err != nil {
return err
}
if t.Req.GetPartitionName() == Params.CommonCfg.DefaultPartitionName.GetValue() {
return errors.New("default partition cannot be deleted")
}
collMeta, err := t.core.meta.GetCollectionByName(ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), t.GetTs())
if err != nil {
// Is this idempotent?
return err
}
t.collMeta = collMeta
return nil
}
func (t *dropPartitionTask) Execute(ctx context.Context) error {
partID := common.InvalidPartitionID
for _, partition := range t.collMeta.Partitions {
if partition.PartitionName == t.Req.GetPartitionName() {
partID = partition.PartitionID
break
}
}
if partID == common.InvalidPartitionID {
log.Ctx(ctx).Warn("drop an non-existent partition", zap.String("collection", t.Req.GetCollectionName()), zap.String("partition", t.Req.GetPartitionName()))
// make dropping partition idempotent.
return nil
}
return executeDropPartitionTaskSteps(ctx, t.core,
t.Req.GetPartitionName(), partID,
t.collMeta, t.Req.GetDbName(),
t.Req.GetBase().GetReplicateInfo().GetIsReplicate(), t.GetTs())
}
func (t *dropPartitionTask) GetLockerKey() LockerKey {
collection := t.core.getCollectionIDStr(t.ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), 0)
return NewLockerKeyChain(
NewClusterLockerKey(false),
NewDatabaseLockerKey(t.Req.GetDbName(), false),
NewCollectionLockerKey(collection, true),
)
}
func executeDropPartitionTaskSteps(ctx context.Context,
core *Core,
partitionName string,
partitionID UniqueID,
col *model.Collection,
dbName string,
isReplicate bool,
ts Timestamp,
) error {
redoTask := newBaseRedoTask(core.stepExecutor)
redoTask.AddSyncStep(&expireCacheStep{
baseStep: baseStep{core: core},
dbName: dbName,
collectionNames: []string{col.Name},
collectionID: col.CollectionID,
partitionName: partitionName,
ts: ts,
opts: []proxyutil.ExpireCacheOpt{proxyutil.SetMsgType(commonpb.MsgType_DropPartition)},
})
redoTask.AddSyncStep(&changePartitionStateStep{
baseStep: baseStep{core: core},
collectionID: col.CollectionID,
partitionID: partitionID,
state: pb.PartitionState_PartitionDropping,
ts: ts,
})
redoTask.AddAsyncStep(&deletePartitionDataStep{
baseStep: baseStep{core: core},
pchans: col.PhysicalChannelNames,
vchans: col.VirtualChannelNames,
partition: &model.Partition{
PartitionID: partitionID,
PartitionName: partitionName,
CollectionID: col.CollectionID,
},
isSkip: isReplicate,
})
redoTask.AddAsyncStep(newConfirmGCStep(core, col.CollectionID, partitionID))
redoTask.AddAsyncStep(&removePartitionMetaStep{
baseStep: baseStep{core: core},
dbID: col.DBID,
collectionID: col.CollectionID,
partitionID: partitionID,
// This ts is less than the ts when we notify data nodes to drop partition, but it's OK since we have already
// marked this partition as deleted. If we want to make this ts greater than the notification's ts, we should
// wrap a step who will have these children and connect them with ts.
ts: ts,
})
return redoTask.Execute(ctx)
}

View File

@ -1,221 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/metastore/model"
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
)
func Test_dropPartitionTask_Prepare(t *testing.T) {
t.Run("invalid msg type", func(t *testing.T) {
task := &dropPartitionTask{
Req: &milvuspb.DropPartitionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropCollection},
},
}
err := task.Prepare(context.Background())
assert.Error(t, err)
})
t.Run("drop default partition", func(t *testing.T) {
task := &dropPartitionTask{
Req: &milvuspb.DropPartitionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropPartition},
PartitionName: Params.CommonCfg.DefaultPartitionName.GetValue(),
},
}
err := task.Prepare(context.Background())
assert.Error(t, err)
})
t.Run("failed to get collection meta", func(t *testing.T) {
core := newTestCore(withInvalidMeta())
task := &dropPartitionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.DropPartitionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropPartition},
},
}
err := task.Prepare(context.Background())
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
collectionName := funcutil.GenRandomStr()
coll := &model.Collection{Name: collectionName}
meta := mockrootcoord.NewIMetaTable(t)
meta.On("GetCollectionByName",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(coll.Clone(), nil)
core := newTestCore(withMeta(meta))
task := &dropPartitionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.DropPartitionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropPartition},
CollectionName: collectionName,
},
}
err := task.Prepare(context.Background())
assert.NoError(t, err)
assert.True(t, coll.Equal(*task.collMeta))
})
}
func Test_dropPartitionTask_Execute(t *testing.T) {
t.Run("drop non-existent partition", func(t *testing.T) {
collectionName := funcutil.GenRandomStr()
partitionName := funcutil.GenRandomStr()
coll := &model.Collection{Name: collectionName, Partitions: []*model.Partition{}}
task := &dropPartitionTask{
Req: &milvuspb.DropPartitionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropPartition},
CollectionName: collectionName,
PartitionName: partitionName,
},
collMeta: coll.Clone(),
}
err := task.Execute(context.Background())
assert.NoError(t, err)
})
t.Run("failed to expire cache", func(t *testing.T) {
collectionName := funcutil.GenRandomStr()
partitionName := funcutil.GenRandomStr()
coll := &model.Collection{Name: collectionName, Partitions: []*model.Partition{{PartitionName: partitionName}}}
core := newTestCore(withInvalidProxyManager())
task := &dropPartitionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.DropPartitionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropPartition},
CollectionName: collectionName,
PartitionName: partitionName,
},
collMeta: coll.Clone(),
}
err := task.Execute(context.Background())
assert.Error(t, err)
})
t.Run("failed to change partition state", func(t *testing.T) {
collectionName := funcutil.GenRandomStr()
partitionName := funcutil.GenRandomStr()
coll := &model.Collection{Name: collectionName, Partitions: []*model.Partition{{PartitionName: partitionName}}}
core := newTestCore(withValidProxyManager(), withInvalidMeta())
task := &dropPartitionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.DropPartitionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropPartition},
CollectionName: collectionName,
PartitionName: partitionName,
},
collMeta: coll.Clone(),
}
err := task.Execute(context.Background())
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
confirmGCInterval = time.Millisecond
defer restoreConfirmGCInterval()
collectionName := funcutil.GenRandomStr()
partitionName := funcutil.GenRandomStr()
coll := &model.Collection{Name: collectionName, Partitions: []*model.Partition{{PartitionName: partitionName}}}
removePartitionMetaCalled := false
removePartitionMetaChan := make(chan struct{}, 1)
meta := mockrootcoord.NewIMetaTable(t)
meta.On("ChangePartitionState",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(nil)
meta.On("RemovePartition",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(func(ctx context.Context, dbID int64, collectionID int64, partitionID int64, ts uint64) error {
removePartitionMetaCalled = true
removePartitionMetaChan <- struct{}{}
return nil
})
gc := mockrootcoord.NewGarbageCollector(t)
deletePartitionCalled := false
deletePartitionChan := make(chan struct{}, 1)
gc.EXPECT().GcPartitionData(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, pChannels, vchannel []string, coll *model.Partition) (Timestamp, error) {
deletePartitionChan <- struct{}{}
deletePartitionCalled = true
time.Sleep(confirmGCInterval)
return 0, nil
})
broker := newMockBroker()
broker.GCConfirmFunc = func(ctx context.Context, collectionID, partitionID UniqueID) bool {
return true
}
broker.DropCollectionIndexFunc = func(ctx context.Context, collID UniqueID, partIDs []UniqueID) error {
return nil
}
broker.ReleasePartitionsFunc = func(ctx context.Context, collectionID UniqueID, partitionIDs ...UniqueID) error {
return nil
}
core := newTestCore(
withValidProxyManager(),
withMeta(meta),
withGarbageCollector(gc),
withBroker(broker))
task := &dropPartitionTask{
baseTask: newBaseTask(context.Background(), core),
Req: &milvuspb.DropPartitionRequest{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_DropPartition},
CollectionName: collectionName,
PartitionName: partitionName,
},
collMeta: coll.Clone(),
}
err := task.Execute(context.Background())
assert.NoError(t, err)
// check if redo worked.
<-removePartitionMetaChan
assert.True(t, removePartitionMetaCalled)
<-deletePartitionChan
assert.True(t, deletePartitionCalled)
})
}

View File

@ -1,352 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/coordinator/snmanager"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/util/streamingutil"
ms "github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
)
//go:generate mockery --name=GarbageCollector --outpkg=mockrootcoord --filename=garbage_collector.go --with-expecter --testonly
type GarbageCollector interface {
ReDropCollection(collMeta *model.Collection, ts Timestamp)
RemoveCreatingCollection(collMeta *model.Collection)
ReDropPartition(dbID int64, pChannels, vchannels []string, partition *model.Partition, ts Timestamp)
RemoveCreatingPartition(dbID int64, partition *model.Partition, ts Timestamp)
GcCollectionData(ctx context.Context, coll *model.Collection) (ddlTs Timestamp, err error)
GcPartitionData(ctx context.Context, pChannels, vchannels []string, partition *model.Partition) (ddlTs Timestamp, err error)
}
type bgGarbageCollector struct {
s *Core
}
func newBgGarbageCollector(s *Core) *bgGarbageCollector {
return &bgGarbageCollector{s: s}
}
func (c *bgGarbageCollector) ReDropCollection(collMeta *model.Collection, ts Timestamp) {
// TODO: remove this after data gc can be notified by rpc.
c.s.chanTimeTick.addDmlChannels(collMeta.PhysicalChannelNames...)
redo := newBaseRedoTask(c.s.stepExecutor)
redo.AddAsyncStep(&releaseCollectionStep{
baseStep: baseStep{core: c.s},
collectionID: collMeta.CollectionID,
})
redo.AddAsyncStep(&dropIndexStep{
baseStep: baseStep{core: c.s},
collID: collMeta.CollectionID,
partIDs: nil,
})
redo.AddAsyncStep(&deleteCollectionDataStep{
baseStep: baseStep{core: c.s},
coll: collMeta,
isSkip: !Params.CommonCfg.TTMsgEnabled.GetAsBool(),
})
redo.AddAsyncStep(&removeDmlChannelsStep{
baseStep: baseStep{core: c.s},
pChannels: collMeta.PhysicalChannelNames,
})
redo.AddAsyncStep(newConfirmGCStep(c.s, collMeta.CollectionID, allPartition))
redo.AddAsyncStep(&deleteCollectionMetaStep{
baseStep: baseStep{core: c.s},
collectionID: collMeta.CollectionID,
// This ts is less than the ts when we notify data nodes to drop collection, but it's OK since we have already
// marked this collection as deleted. If we want to make this ts greater than the notification's ts, we should
// wrap a step who will have these three children and connect them with ts.
ts: ts,
})
// err is ignored since no sync steps will be executed.
_ = redo.Execute(context.Background())
}
func (c *bgGarbageCollector) RemoveCreatingCollection(collMeta *model.Collection) {
// TODO: remove this after data gc can be notified by rpc.
c.s.chanTimeTick.addDmlChannels(collMeta.PhysicalChannelNames...)
redo := newBaseRedoTask(c.s.stepExecutor)
redo.AddAsyncStep(&unwatchChannelsStep{
baseStep: baseStep{core: c.s},
collectionID: collMeta.CollectionID,
channels: collectionChannels{
virtualChannels: collMeta.VirtualChannelNames,
physicalChannels: collMeta.PhysicalChannelNames,
},
isSkip: !Params.CommonCfg.TTMsgEnabled.GetAsBool(),
})
redo.AddAsyncStep(&removeDmlChannelsStep{
baseStep: baseStep{core: c.s},
pChannels: collMeta.PhysicalChannelNames,
})
redo.AddAsyncStep(&deleteCollectionMetaStep{
baseStep: baseStep{core: c.s},
collectionID: collMeta.CollectionID,
// When we undo createCollectionTask, this ts may be less than the ts at which the channels are unwatched.
ts: collMeta.CreateTime,
})
// err is ignored since no sync steps will be executed.
_ = redo.Execute(context.Background())
}
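// ReDropPartition resumes an unfinished partition drop: it broadcasts the drop-partition data message,
// waits for GC confirmation and removes the partition meta.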
func (c *bgGarbageCollector) ReDropPartition(dbID int64, pChannels, vchannels []string, partition *model.Partition, ts Timestamp) {
// TODO: remove this after data gc can be notified by rpc.
c.s.chanTimeTick.addDmlChannels(pChannels...)
redo := newBaseRedoTask(c.s.stepExecutor)
redo.AddAsyncStep(&deletePartitionDataStep{
baseStep: baseStep{core: c.s},
pchans: pChannels,
vchans: vchannels,
partition: partition,
isSkip: !Params.CommonCfg.TTMsgEnabled.GetAsBool(),
})
redo.AddAsyncStep(&removeDmlChannelsStep{
baseStep: baseStep{core: c.s},
pChannels: pChannels,
})
redo.AddAsyncStep(newConfirmGCStep(c.s, partition.CollectionID, partition.PartitionID))
redo.AddAsyncStep(&removePartitionMetaStep{
baseStep: baseStep{core: c.s},
dbID: dbID,
collectionID: partition.CollectionID,
partitionID: partition.PartitionID,
// This ts is less than the ts when we notify data nodes to drop partition, but it's OK since we have already
// marked this partition as deleted. If we want to make this ts greater than the notification's ts, we should
// wrap a step that has these children and connects them with the ts.
ts: ts,
})
// err is ignored since no sync steps will be executed.
_ = redo.Execute(context.Background())
}
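// RemoveCreatingPartition rolls back a partition whose creation failed by releasing it and removing its meta.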
func (c *bgGarbageCollector) RemoveCreatingPartition(dbID int64, partition *model.Partition, ts Timestamp) {
redoTask := newBaseRedoTask(c.s.stepExecutor)
redoTask.AddAsyncStep(&releasePartitionsStep{
baseStep: baseStep{core: c.s},
collectionID: partition.CollectionID,
partitionIDs: []int64{partition.PartitionID},
})
redoTask.AddAsyncStep(&removePartitionMetaStep{
baseStep: baseStep{core: c.s},
dbID: dbID,
collectionID: partition.CollectionID,
partitionID: partition.PartitionID,
ts: ts,
})
// err is ignored since no sync steps will be executed.
_ = redoTask.Execute(context.Background())
}
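// notifyCollectionGc broadcasts the DropCollection message so that data GC can be triggered.
// It prefers the streaming service when it is ready and falls back to the msgstream otherwise,
// returning the DDL timestamp of the broadcast.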
func (c *bgGarbageCollector) notifyCollectionGc(ctx context.Context, coll *model.Collection) (ddlTs Timestamp, err error) {
if streamingutil.IsStreamingServiceEnabled() {
notifier := snmanager.NewStreamingReadyNotifier()
if err := snmanager.StaticStreamingNodeManager.RegisterStreamingEnabledListener(ctx, notifier); err != nil {
return 0, err
}
if notifier.IsReady() {
// streaming service is ready, so we release the ready notifier and send it into streaming service.
notifier.Release()
return c.notifyCollectionGcByStreamingService(ctx, coll)
}
// streaming service is not ready, so we send it into msgstream.
defer notifier.Release()
}
ts, err := c.s.tsoAllocator.GenerateTSO(1)
if err != nil {
return 0, err
}
msgPack := ms.MsgPack{}
msg := &ms.DropCollectionMsg{
BaseMsg: ms.BaseMsg{
Ctx: ctx,
BeginTimestamp: ts,
EndTimestamp: ts,
HashValues: []uint32{0},
},
DropCollectionRequest: c.generateDropRequest(coll, ts),
}
msgPack.Msgs = append(msgPack.Msgs, msg)
if err := c.s.chanTimeTick.broadcastDmlChannels(coll.PhysicalChannelNames, &msgPack); err != nil {
return 0, err
}
return ts, nil
}
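// generateDropRequest builds the DropCollectionRequest body carried by the drop message.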
func (c *bgGarbageCollector) generateDropRequest(coll *model.Collection, ts uint64) *msgpb.DropCollectionRequest {
return &msgpb.DropCollectionRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_DropCollection),
commonpbutil.WithTimeStamp(ts),
commonpbutil.WithSourceID(c.s.session.GetServerID()),
),
CollectionName: coll.Name,
CollectionID: coll.CollectionID,
}
}
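// notifyCollectionGcByStreamingService appends one DropCollection message per vchannel into the WAL
// and returns the maximum time tick of the appended messages.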
func (c *bgGarbageCollector) notifyCollectionGcByStreamingService(ctx context.Context, coll *model.Collection) (uint64, error) {
req := c.generateDropRequest(coll, 0) // ts is given by streamingnode.
msgs := make([]message.MutableMessage, 0, len(coll.VirtualChannelNames))
for _, vchannel := range coll.VirtualChannelNames {
msg, err := message.NewDropCollectionMessageBuilderV1().
WithVChannel(vchannel).
WithHeader(&message.DropCollectionMessageHeader{
CollectionId: coll.CollectionID,
}).
WithBody(req).
BuildMutable()
if err != nil {
return 0, err
}
msgs = append(msgs, msg)
}
resp := streaming.WAL().AppendMessages(ctx, msgs...)
if err := resp.UnwrapFirstError(); err != nil {
return 0, err
}
return resp.MaxTimeTick(), nil
}
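// notifyPartitionGc broadcasts the DropPartition message on the given pchannels via msgstream
// and returns the allocated timestamp.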
func (c *bgGarbageCollector) notifyPartitionGc(ctx context.Context, pChannels []string, partition *model.Partition) (ddlTs Timestamp, err error) {
ts, err := c.s.tsoAllocator.GenerateTSO(1)
if err != nil {
return 0, err
}
msgPack := ms.MsgPack{}
msg := &ms.DropPartitionMsg{
BaseMsg: ms.BaseMsg{
Ctx: ctx,
BeginTimestamp: ts,
EndTimestamp: ts,
HashValues: []uint32{0},
},
DropPartitionRequest: &msgpb.DropPartitionRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_DropPartition),
commonpbutil.WithTimeStamp(ts),
commonpbutil.WithSourceID(c.s.session.GetServerID()),
),
PartitionName: partition.PartitionName,
CollectionID: partition.CollectionID,
PartitionID: partition.PartitionID,
},
}
msgPack.Msgs = append(msgPack.Msgs, msg)
if err := c.s.chanTimeTick.broadcastDmlChannels(pChannels, &msgPack); err != nil {
return 0, err
}
return ts, nil
}
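// notifyPartitionGcByStreamingService broadcasts a DropPartition message to the given vchannels through
// the WAL broadcaster and returns the maximum time tick among the append results.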
func (c *bgGarbageCollector) notifyPartitionGcByStreamingService(ctx context.Context, vchannels []string, partition *model.Partition) (uint64, error) {
req := &msgpb.DropPartitionRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_DropPartition),
commonpbutil.WithTimeStamp(0), // Timetick is given by streamingnode.
commonpbutil.WithSourceID(c.s.session.GetServerID()),
),
PartitionName: partition.PartitionName,
CollectionID: partition.CollectionID,
PartitionID: partition.PartitionID,
}
msg := message.NewDropPartitionMessageBuilderV1().
WithBroadcast(vchannels).
WithHeader(&message.DropPartitionMessageHeader{
CollectionId: partition.CollectionID,
PartitionId: partition.PartitionID,
}).
WithBody(req).
MustBuildBroadcast()
r, err := streaming.WAL().Broadcast().Append(ctx, msg)
if err != nil {
return 0, err
}
maxTimeTick := uint64(0)
for _, r := range r.AppendResults {
if r.TimeTick > maxTimeTick {
maxTimeTick = r.TimeTick
}
}
return maxTimeTick, nil
}
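// GcCollectionData notifies the cluster to GC the data of a dropped collection and updates the
// ddl-ts lock manager with the returned timestamp.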
func (c *bgGarbageCollector) GcCollectionData(ctx context.Context, coll *model.Collection) (ddlTs Timestamp, err error) {
c.s.ddlTsLockManager.Lock()
c.s.ddlTsLockManager.AddRefCnt(1)
defer c.s.ddlTsLockManager.AddRefCnt(-1)
defer c.s.ddlTsLockManager.Unlock()
ddlTs, err = c.notifyCollectionGc(ctx, coll)
if err != nil {
return 0, err
}
c.s.ddlTsLockManager.UpdateLastTs(ddlTs)
return ddlTs, nil
}
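// GcPartitionData notifies the cluster to GC the data of a dropped partition, choosing the streaming
// service path when it is ready, and updates the ddl-ts lock manager with the returned timestamp.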
func (c *bgGarbageCollector) GcPartitionData(ctx context.Context, pChannels, vchannels []string, partition *model.Partition) (ddlTs Timestamp, err error) {
c.s.ddlTsLockManager.Lock()
c.s.ddlTsLockManager.AddRefCnt(1)
defer c.s.ddlTsLockManager.AddRefCnt(-1)
defer c.s.ddlTsLockManager.Unlock()
if streamingutil.IsStreamingServiceEnabled() {
notifier := snmanager.NewStreamingReadyNotifier()
if err := snmanager.StaticStreamingNodeManager.RegisterStreamingEnabledListener(ctx, notifier); err != nil {
return 0, err
}
if notifier.IsReady() {
// streaming service is ready, so we release the ready notifier and send it into streaming service.
notifier.Release()
ddlTs, err = c.notifyPartitionGcByStreamingService(ctx, vchannels, partition)
} else {
// streaming service is not ready, so we send it into msgstream while holding the notifier.
defer notifier.Release()
ddlTs, err = c.notifyPartitionGc(ctx, pChannels, partition)
}
} else {
ddlTs, err = c.notifyPartitionGc(ctx, pChannels, partition)
}
if err != nil {
return 0, err
}
c.s.ddlTsLockManager.UpdateLastTs(ddlTs)
return ddlTs, nil
}

View File

@ -1,583 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package rootcoord
import (
"context"
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus/internal/coordinator/snmanager"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/mocks/distributed/mock_streaming"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer"
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
mocktso "github.com/milvus-io/milvus/internal/tso/mocks"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
func TestGarbageCollectorCtx_ReDropCollection(t *testing.T) {
oldValue := confirmGCInterval
defer func() {
confirmGCInterval = oldValue
}()
confirmGCInterval = 0
t.Run("failed to release collection", func(t *testing.T) {
broker := newMockBroker()
broker.ReleaseCollectionFunc = func(ctx context.Context, collectionID UniqueID) error {
return errors.New("error mock ReleaseCollection")
}
ticker := newTickerWithMockNormalStream()
core := newTestCore(withBroker(broker), withTtSynchronizer(ticker), withValidProxyManager())
gc := newBgGarbageCollector(core)
gc.ReDropCollection(&model.Collection{}, 1000)
})
t.Run("failed to DropCollectionIndex", func(t *testing.T) {
broker := newMockBroker()
releaseCollectionCalled := false
releaseCollectionChan := make(chan struct{}, 1)
broker.ReleaseCollectionFunc = func(ctx context.Context, collectionID UniqueID) error {
releaseCollectionCalled = true
releaseCollectionChan <- struct{}{}
return nil
}
broker.DropCollectionIndexFunc = func(ctx context.Context, collID UniqueID, partIDs []UniqueID) error {
return errors.New("error mock DropCollectionIndex")
}
ticker := newTickerWithMockNormalStream()
core := newTestCore(withBroker(broker), withTtSynchronizer(ticker), withValidProxyManager())
gc := newBgGarbageCollector(core)
core.garbageCollector = gc
gc.ReDropCollection(&model.Collection{}, 1000)
<-releaseCollectionChan
assert.True(t, releaseCollectionCalled)
})
t.Run("failed to GcCollectionData", func(t *testing.T) {
broker := newMockBroker()
releaseCollectionCalled := false
releaseCollectionChan := make(chan struct{}, 1)
broker.ReleaseCollectionFunc = func(ctx context.Context, collectionID UniqueID) error {
releaseCollectionCalled = true
releaseCollectionChan <- struct{}{}
return nil
}
dropCollectionIndexCalled := false
dropCollectionIndexChan := make(chan struct{}, 1)
broker.DropCollectionIndexFunc = func(ctx context.Context, collID UniqueID, partIDs []UniqueID) error {
dropCollectionIndexCalled = true
dropCollectionIndexChan <- struct{}{}
return nil
}
ticker := newTickerWithMockFailStream() // failed to broadcast drop msg.
tsoAllocator := newMockTsoAllocator()
tsoAllocator.GenerateTSOF = func(count uint32) (uint64, error) {
return 100, nil
}
core := newTestCore(withBroker(broker), withTtSynchronizer(ticker), withTsoAllocator(tsoAllocator), withValidProxyManager())
core.ddlTsLockManager = newDdlTsLockManager(core.tsoAllocator)
gc := newBgGarbageCollector(core)
core.garbageCollector = gc
shardsNum := common.DefaultShardsNum
pchans := ticker.getDmlChannelNames(int(shardsNum))
gc.ReDropCollection(&model.Collection{PhysicalChannelNames: pchans}, 1000)
<-releaseCollectionChan
assert.True(t, releaseCollectionCalled)
<-dropCollectionIndexChan
assert.True(t, dropCollectionIndexCalled)
})
t.Run("failed to remove collection", func(t *testing.T) {
broker := newMockBroker()
releaseCollectionCalled := false
releaseCollectionChan := make(chan struct{}, 1)
broker.ReleaseCollectionFunc = func(ctx context.Context, collectionID UniqueID) error {
releaseCollectionCalled = true
releaseCollectionChan <- struct{}{}
return nil
}
gcConfirmCalled := false
gcConfirmChan := make(chan struct{})
broker.GCConfirmFunc = func(ctx context.Context, collectionID, partitionID UniqueID) bool {
gcConfirmCalled = true
close(gcConfirmChan)
return true
}
dropCollectionIndexCalled := false
dropCollectionIndexChan := make(chan struct{}, 1)
broker.DropCollectionIndexFunc = func(ctx context.Context, collID UniqueID, partIDs []UniqueID) error {
dropCollectionIndexCalled = true
dropCollectionIndexChan <- struct{}{}
return nil
}
dropMetaChan := make(chan struct{}, 1)
meta := mockrootcoord.NewIMetaTable(t)
meta.On("RemoveCollection",
mock.Anything, // context.Context
mock.AnythingOfType("int64"),
mock.AnythingOfType("uint64")).
Run(func(args mock.Arguments) {
dropMetaChan <- struct{}{}
}).
Return(errors.New("error mock RemoveCollection"))
ticker := newTickerWithMockNormalStream()
tsoAllocator := newMockTsoAllocator()
tsoAllocator.GenerateTSOF = func(count uint32) (uint64, error) {
return 100, nil
}
core := newTestCore(withBroker(broker),
withTtSynchronizer(ticker),
withTsoAllocator(tsoAllocator),
withValidProxyManager(),
withMeta(meta))
core.ddlTsLockManager = newDdlTsLockManager(core.tsoAllocator)
gc := newBgGarbageCollector(core)
core.garbageCollector = gc
gc.ReDropCollection(&model.Collection{}, 1000)
<-releaseCollectionChan
assert.True(t, releaseCollectionCalled)
<-dropCollectionIndexChan
assert.True(t, dropCollectionIndexCalled)
<-gcConfirmChan
assert.True(t, gcConfirmCalled)
<-dropMetaChan
})
t.Run("normal case", func(t *testing.T) {
broker := newMockBroker()
releaseCollectionCalled := false
releaseCollectionChan := make(chan struct{}, 1)
broker.ReleaseCollectionFunc = func(ctx context.Context, collectionID UniqueID) error {
releaseCollectionCalled = true
releaseCollectionChan <- struct{}{}
return nil
}
dropCollectionIndexCalled := false
dropCollectionIndexChan := make(chan struct{}, 1)
broker.DropCollectionIndexFunc = func(ctx context.Context, collID UniqueID, partIDs []UniqueID) error {
dropCollectionIndexCalled = true
dropCollectionIndexChan <- struct{}{}
return nil
}
gcConfirmCalled := false
gcConfirmChan := make(chan struct{})
broker.GCConfirmFunc = func(ctx context.Context, collectionID, partitionID UniqueID) bool {
gcConfirmCalled = true
close(gcConfirmChan)
return true
}
meta := mockrootcoord.NewIMetaTable(t)
removeCollectionCalled := false
removeCollectionChan := make(chan struct{}, 1)
meta.On("RemoveCollection",
mock.Anything, // context.Context
mock.AnythingOfType("int64"),
mock.AnythingOfType("uint64")).
Return(func(ctx context.Context, collectionID int64, ts uint64) error {
removeCollectionCalled = true
removeCollectionChan <- struct{}{}
return nil
})
ticker := newTickerWithMockNormalStream()
tsoAllocator := newMockTsoAllocator()
tsoAllocator.GenerateTSOF = func(count uint32) (uint64, error) {
return 100, nil
}
core := newTestCore(withBroker(broker),
withTtSynchronizer(ticker),
withTsoAllocator(tsoAllocator),
withValidProxyManager(),
withMeta(meta))
core.ddlTsLockManager = newDdlTsLockManager(core.tsoAllocator)
gc := newBgGarbageCollector(core)
core.garbageCollector = gc
gc.ReDropCollection(&model.Collection{}, 1000)
<-releaseCollectionChan
assert.True(t, releaseCollectionCalled)
<-dropCollectionIndexChan
assert.True(t, dropCollectionIndexCalled)
<-removeCollectionChan
assert.True(t, removeCollectionCalled)
<-gcConfirmChan
assert.True(t, gcConfirmCalled)
})
}
func TestGarbageCollectorCtx_RemoveCreatingCollection(t *testing.T) {
t.Run("failed to UnwatchChannels", func(t *testing.T) {
defer cleanTestEnv()
shardNum := 2
ticker := newRocksMqTtSynchronizer()
pchans := ticker.getDmlChannelNames(shardNum)
tsoAllocator := mocktso.NewAllocator(t)
tsoAllocator.
On("GenerateTSO", mock.AnythingOfType("uint32")).
Return(Timestamp(0), errors.New("error mock GenerateTSO"))
executed := make(chan struct{}, 1)
executor := newMockStepExecutor()
executor.AddStepsFunc = func(s *stepStack) {
s.Execute(context.Background())
executed <- struct{}{}
}
core := newTestCore(withTtSynchronizer(ticker), withTsoAllocator(tsoAllocator), withStepExecutor(executor))
gc := newBgGarbageCollector(core)
core.ddlTsLockManager = newDdlTsLockManager(tsoAllocator)
core.garbageCollector = gc
gc.RemoveCreatingCollection(&model.Collection{PhysicalChannelNames: pchans})
<-executed
})
t.Run("failed to RemoveCollection", func(t *testing.T) {
defer cleanTestEnv()
shardNum := 2
ticker := newRocksMqTtSynchronizer()
pchans := ticker.getDmlChannelNames(shardNum)
tsoAllocator := mocktso.NewAllocator(t)
tsoAllocator.
On("GenerateTSO", mock.AnythingOfType("uint32")).
Return(Timestamp(100), nil)
for _, pchan := range pchans {
ticker.syncedTtHistogram.update(pchan, 101)
}
removeCollectionCalled := false
removeCollectionChan := make(chan struct{}, 1)
meta := mockrootcoord.NewIMetaTable(t)
meta.On("RemoveCollection",
mock.Anything,
mock.Anything,
mock.Anything,
).Return(func(ctx context.Context, collectionID UniqueID, ts Timestamp) error {
removeCollectionCalled = true
removeCollectionChan <- struct{}{}
return errors.New("error mock RemoveCollection")
})
core := newTestCore(withTtSynchronizer(ticker), withMeta(meta), withTsoAllocator(tsoAllocator))
gc := newBgGarbageCollector(core)
core.ddlTsLockManager = newDdlTsLockManager(tsoAllocator)
core.garbageCollector = gc
gc.RemoveCreatingCollection(&model.Collection{PhysicalChannelNames: pchans})
<-removeCollectionChan
assert.True(t, removeCollectionCalled) // though it fails.
})
t.Run("normal case", func(t *testing.T) {
defer cleanTestEnv()
shardNum := 2
ticker := newRocksMqTtSynchronizer()
pchans := ticker.getDmlChannelNames(shardNum)
tsoAllocator := mocktso.NewAllocator(t)
tsoAllocator.
On("GenerateTSO", mock.AnythingOfType("uint32")).
Return(Timestamp(100), nil)
for _, pchan := range pchans {
ticker.syncedTtHistogram.update(pchan, 101)
}
removeCollectionCalled := false
removeCollectionChan := make(chan struct{}, 1)
meta := mockrootcoord.NewIMetaTable(t)
meta.On("RemoveCollection",
mock.Anything,
mock.Anything,
mock.Anything,
).Return(func(ctx context.Context, collectionID UniqueID, ts Timestamp) error {
removeCollectionCalled = true
removeCollectionChan <- struct{}{}
return nil
})
core := newTestCore(withTtSynchronizer(ticker), withMeta(meta), withTsoAllocator(tsoAllocator))
gc := newBgGarbageCollector(core)
core.ddlTsLockManager = newDdlTsLockManager(tsoAllocator)
core.garbageCollector = gc
gc.RemoveCreatingCollection(&model.Collection{PhysicalChannelNames: pchans})
<-removeCollectionChan
assert.True(t, removeCollectionCalled)
})
}
func TestGarbageCollectorCtx_ReDropPartition(t *testing.T) {
oldValue := confirmGCInterval
defer func() {
confirmGCInterval = oldValue
}()
confirmGCInterval = 0
t.Run("failed to GcPartitionData", func(t *testing.T) {
ticker := newTickerWithMockFailStream() // failed to broadcast drop msg.
shardsNum := int(common.DefaultShardsNum)
pchans := ticker.getDmlChannelNames(shardsNum)
tsoAllocator := newMockTsoAllocator()
tsoAllocator.GenerateTSOF = func(count uint32) (uint64, error) {
return 100, nil
}
core := newTestCore(withTtSynchronizer(ticker), withTsoAllocator(tsoAllocator), withDropIndex())
core.ddlTsLockManager = newDdlTsLockManager(core.tsoAllocator)
gc := newBgGarbageCollector(core)
core.garbageCollector = gc
gc.ReDropPartition(0, pchans, nil, &model.Partition{}, 100000)
})
t.Run("failed to RemovePartition", func(t *testing.T) {
ticker := newTickerWithMockNormalStream()
shardsNum := int(common.DefaultShardsNum)
pchans := ticker.getDmlChannelNames(shardsNum)
meta := mockrootcoord.NewIMetaTable(t)
removePartitionCalled := false
removePartitionChan := make(chan struct{}, 1)
meta.On("RemovePartition",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(func(ctx context.Context, dbID int64, collectionID int64, partitionID int64, ts uint64) error {
removePartitionCalled = true
removePartitionChan <- struct{}{}
return errors.New("error mock RemovePartition")
})
broker := newMockBroker()
gcConfirmCalled := false
gcConfirmChan := make(chan struct{})
broker.GCConfirmFunc = func(ctx context.Context, collectionID, partitionID UniqueID) bool {
gcConfirmCalled = true
close(gcConfirmChan)
return true
}
tsoAllocator := newMockTsoAllocator()
tsoAllocator.GenerateTSOF = func(count uint32) (uint64, error) {
return 100, nil
}
core := newTestCore(withMeta(meta), withTtSynchronizer(ticker), withTsoAllocator(tsoAllocator), withDropIndex(), withBroker(broker))
core.ddlTsLockManager = newDdlTsLockManager(core.tsoAllocator)
gc := newBgGarbageCollector(core)
core.garbageCollector = gc
gc.ReDropPartition(0, pchans, nil, &model.Partition{}, 100000)
<-gcConfirmChan
assert.True(t, gcConfirmCalled)
<-removePartitionChan
assert.True(t, removePartitionCalled)
})
t.Run("normal case", func(t *testing.T) {
ticker := newTickerWithMockNormalStream()
shardsNum := int(common.DefaultShardsNum)
pchans := ticker.getDmlChannelNames(shardsNum)
removePartitionCalled := false
removePartitionChan := make(chan struct{}, 1)
meta := mockrootcoord.NewIMetaTable(t)
meta.On("RemovePartition",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(func(ctx context.Context, dbID int64, collectionID int64, partitionID int64, ts uint64) error {
removePartitionCalled = true
removePartitionChan <- struct{}{}
return nil
})
broker := newMockBroker()
gcConfirmCalled := false
gcConfirmChan := make(chan struct{})
broker.GCConfirmFunc = func(ctx context.Context, collectionID, partitionID UniqueID) bool {
gcConfirmCalled = true
close(gcConfirmChan)
return true
}
tsoAllocator := newMockTsoAllocator()
tsoAllocator.GenerateTSOF = func(count uint32) (uint64, error) {
return 100, nil
}
core := newTestCore(withMeta(meta), withTtSynchronizer(ticker), withTsoAllocator(tsoAllocator), withDropIndex(), withBroker(broker))
core.ddlTsLockManager = newDdlTsLockManager(core.tsoAllocator)
gc := newBgGarbageCollector(core)
core.garbageCollector = gc
gc.ReDropPartition(0, pchans, nil, &model.Partition{}, 100000)
<-gcConfirmChan
assert.True(t, gcConfirmCalled)
<-removePartitionChan
assert.True(t, removePartitionCalled)
})
}
func TestGarbageCollector_RemoveCreatingPartition(t *testing.T) {
t.Run("test normal", func(t *testing.T) {
defer cleanTestEnv()
ticker := newTickerWithMockNormalStream()
tsoAllocator := mocktso.NewAllocator(t)
signal := make(chan struct{}, 1)
meta := mockrootcoord.NewIMetaTable(t)
meta.EXPECT().RemovePartition(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(nil).
Run(func(ctx context.Context, dbID, collectionID int64, partitionID int64, ts uint64) {
signal <- struct{}{}
})
qc := mocks.NewMixCoord(t)
qc.EXPECT().ReleasePartitions(mock.Anything, mock.Anything).Return(merr.Success(), nil)
core := newTestCore(withTtSynchronizer(ticker),
withMeta(meta),
withTsoAllocator(tsoAllocator),
withMixCoord(qc))
gc := newBgGarbageCollector(core)
core.ddlTsLockManager = newDdlTsLockManager(tsoAllocator)
core.garbageCollector = gc
core.broker = newServerBroker(core)
gc.RemoveCreatingPartition(0, &model.Partition{}, 0)
<-signal
})
t.Run("test ReleasePartitions failed", func(t *testing.T) {
defer cleanTestEnv()
ticker := newTickerWithMockNormalStream()
tsoAllocator := mocktso.NewAllocator(t)
signal := make(chan struct{}, 1)
meta := mockrootcoord.NewIMetaTable(t)
qc := mocks.NewMixCoord(t)
qc.EXPECT().ReleasePartitions(mock.Anything, mock.Anything).
Return(merr.Success(), errors.New("mock err")).
Run(func(ctx context.Context, req *querypb.ReleasePartitionsRequest) {
signal <- struct{}{}
})
core := newTestCore(withTtSynchronizer(ticker),
withMeta(meta),
withTsoAllocator(tsoAllocator),
withMixCoord(qc))
gc := newBgGarbageCollector(core)
core.ddlTsLockManager = newDdlTsLockManager(tsoAllocator)
core.garbageCollector = gc
core.broker = newServerBroker(core)
gc.RemoveCreatingPartition(0, &model.Partition{}, 0)
<-signal
})
t.Run("test RemovePartition failed", func(t *testing.T) {
defer cleanTestEnv()
ticker := newTickerWithMockNormalStream()
tsoAllocator := mocktso.NewAllocator(t)
signal := make(chan struct{}, 1)
meta := mockrootcoord.NewIMetaTable(t)
meta.EXPECT().RemovePartition(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(errors.New("mock err")).
Run(func(ctx context.Context, dbID, collectionID int64, partitionID int64, ts uint64) {
signal <- struct{}{}
})
qc := mocks.NewMixCoord(t)
qc.EXPECT().ReleasePartitions(mock.Anything, mock.Anything).Return(merr.Success(), nil)
core := newTestCore(withTtSynchronizer(ticker),
withMeta(meta),
withTsoAllocator(tsoAllocator),
withMixCoord(qc))
gc := newBgGarbageCollector(core)
core.ddlTsLockManager = newDdlTsLockManager(tsoAllocator)
core.garbageCollector = gc
core.broker = newServerBroker(core)
gc.RemoveCreatingPartition(0, &model.Partition{}, 0)
<-signal
})
}
func TestGcPartitionData(t *testing.T) {
defer cleanTestEnv()
streamingutil.SetStreamingServiceEnabled()
defer streamingutil.UnsetStreamingServiceEnabled()
snmanager.ResetStreamingNodeManager()
b := mock_balancer.NewMockBalancer(t)
b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, cb balancer.WatchChannelAssignmentsCallback) error {
<-ctx.Done()
return ctx.Err()
})
b.EXPECT().RegisterStreamingEnabledNotifier(mock.Anything).Run(func(notifier *syncutil.AsyncTaskNotifier[struct{}]) {
notifier.Cancel()
})
balance.Register(b)
wal := mock_streaming.NewMockWALAccesser(t)
broadcast := mock_streaming.NewMockBroadcast(t)
broadcast.EXPECT().Append(mock.Anything, mock.Anything).Return(&types.BroadcastAppendResult{
BroadcastID: 0,
AppendResults: map[string]*types.AppendResult{
"ch-0": {},
},
}, nil)
wal.EXPECT().Broadcast().Return(broadcast)
streaming.SetWALForTest(wal)
tsoAllocator := mocktso.NewAllocator(t)
core := newTestCore()
gc := newBgGarbageCollector(core)
core.ddlTsLockManager = newDdlTsLockManager(tsoAllocator)
_, err := gc.GcPartitionData(context.Background(), nil, []string{"ch-0", "ch-1"}, &model.Partition{})
assert.NoError(t, err)
}

View File

@ -50,7 +50,13 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
var errIgnoredAlterAlias = errors.New("ignored alter alias") // alias already created on current collection, so it can be ignored.
var (
errIgnoredAlterAlias = errors.New("ignored alter alias") // alias already created on current collection, so it can be ignored.
errIgnoredCreateCollection = errors.New("ignored create collection") // create collection with the same schema, so it can be ignored.
errIgnoerdCreatePartition = errors.New("ignored create partition") // partition already exists, so it can be ignored.
errIgnoredDropCollection = errors.New("ignored drop collection") // collection or database not found on drop, so it can be ignored.
errIgnoredDropPartition = errors.New("ignored drop partition") // partition not found on drop, so it can be ignored.
)
type MetaTableChecker interface {
RBACChecker
@ -75,7 +81,7 @@ type IMetaTable interface {
AlterDatabase(ctx context.Context, newDB *model.Database, ts typeutil.Timestamp) error
AddCollection(ctx context.Context, coll *model.Collection) error
ChangeCollectionState(ctx context.Context, collectionID UniqueID, state pb.CollectionState, ts Timestamp) error
DropCollection(ctx context.Context, collectionID UniqueID, ts Timestamp) error
RemoveCollection(ctx context.Context, collectionID UniqueID, ts Timestamp) error
// GetCollectionID retrieves the corresponding collectionID based on the collectionName.
// If the collection does not exist, it will return InvalidCollectionID.
@ -93,7 +99,7 @@ type IMetaTable interface {
GetCollectionVirtualChannels(ctx context.Context, colID int64) []string
GetPChannelInfo(ctx context.Context, pchannel string) *rootcoordpb.GetPChannelInfoResponse
AddPartition(ctx context.Context, partition *model.Partition) error
ChangePartitionState(ctx context.Context, collectionID UniqueID, partitionID UniqueID, state pb.PartitionState, ts Timestamp) error
DropPartition(ctx context.Context, collectionID UniqueID, partitionID UniqueID, ts Timestamp) error
RemovePartition(ctx context.Context, dbID int64, collectionID UniqueID, partitionID UniqueID, ts Timestamp) error
// Alias
@ -497,22 +503,29 @@ func (mt *MetaTable) AddCollection(ctx context.Context, coll *model.Collection)
// Note:
// 1, idempotency check was already done outside;
// 2, no need to check time travel logic, since ts should always be the latest;
db, err := mt.getDatabaseByIDInternal(ctx, coll.DBID, typeutil.MaxTimestamp)
if err != nil {
return err
if coll.State != pb.CollectionState_CollectionCreated {
return fmt.Errorf("collection state should be created, collection name: %s, collection id: %d, state: %s", coll.Name, coll.CollectionID, coll.State)
}
if coll.State != pb.CollectionState_CollectionCreating {
return fmt.Errorf("collection state should be creating, collection name: %s, collection id: %d, state: %s", coll.Name, coll.CollectionID, coll.State)
// check if there's a collection meta with the same collection id.
// merge the collection meta together.
if _, ok := mt.collID2Meta[coll.CollectionID]; ok {
log.Ctx(ctx).Info("collection already created, skip add collection to meta table", zap.Int64("collectionID", coll.CollectionID))
return nil
}
ctx1 := contextutil.WithTenantID(ctx, Params.CommonCfg.ClusterName.GetValue())
if err := mt.catalog.CreateCollection(ctx1, coll, coll.CreateTime); err != nil {
return err
}
mt.collID2Meta[coll.CollectionID] = coll.Clone()
mt.names.insert(db.Name, coll.Name, coll.CollectionID)
mt.names.insert(coll.DBName, coll.Name, coll.CollectionID)
pn := coll.GetPartitionNum(true)
mt.generalCnt += pn * int(coll.ShardsNum)
metrics.RootCoordNumOfCollections.WithLabelValues(coll.DBName).Inc()
metrics.RootCoordNumOfPartitions.WithLabelValues().Add(float64(pn))
channel.StaticPChannelStatsManager.MustGet().AddVChannel(coll.VirtualChannelNames...)
log.Ctx(ctx).Info("add collection to meta table",
@ -524,7 +537,7 @@ func (mt *MetaTable) AddCollection(ctx context.Context, coll *model.Collection)
return nil
}
func (mt *MetaTable) ChangeCollectionState(ctx context.Context, collectionID UniqueID, state pb.CollectionState, ts Timestamp) error {
func (mt *MetaTable) DropCollection(ctx context.Context, collectionID UniqueID, ts Timestamp) error {
mt.ddLock.Lock()
defer mt.ddLock.Unlock()
@ -532,8 +545,12 @@ func (mt *MetaTable) ChangeCollectionState(ctx context.Context, collectionID Uni
if !ok {
return nil
}
if coll.State == pb.CollectionState_CollectionDropping {
return nil
}
clone := coll.Clone()
clone.State = state
clone.State = pb.CollectionState_CollectionDropping
ctx1 := contextutil.WithTenantID(ctx, Params.CommonCfg.ClusterName.GetValue())
if err := mt.catalog.AlterCollection(ctx1, coll, clone, metastore.MODIFY, ts, false); err != nil {
return err
@ -547,21 +564,13 @@ func (mt *MetaTable) ChangeCollectionState(ctx context.Context, collectionID Uni
pn := coll.GetPartitionNum(true)
switch state {
case pb.CollectionState_CollectionCreated:
mt.generalCnt += pn * int(coll.ShardsNum)
metrics.RootCoordNumOfCollections.WithLabelValues(db.Name).Inc()
metrics.RootCoordNumOfPartitions.WithLabelValues().Add(float64(pn))
case pb.CollectionState_CollectionDropping:
mt.generalCnt -= pn * int(coll.ShardsNum)
channel.StaticPChannelStatsManager.MustGet().RemoveVChannel(coll.VirtualChannelNames...)
metrics.RootCoordNumOfCollections.WithLabelValues(db.Name).Dec()
metrics.RootCoordNumOfPartitions.WithLabelValues().Sub(float64(pn))
}
log.Ctx(ctx).Info("change collection state", zap.Int64("collection", collectionID),
zap.String("state", state.String()), zap.Uint64("ts", ts))
mt.generalCnt -= pn * int(coll.ShardsNum)
channel.StaticPChannelStatsManager.MustGet().RemoveVChannel(coll.VirtualChannelNames...)
metrics.RootCoordNumOfCollections.WithLabelValues(db.Name).Dec()
metrics.RootCoordNumOfPartitions.WithLabelValues().Sub(float64(pn))
log.Ctx(ctx).Info("drop collection from meta table", zap.Int64("collection", collectionID),
zap.String("state", coll.State.String()), zap.Uint64("ts", ts))
return nil
}
@ -1042,9 +1051,18 @@ func (mt *MetaTable) AddPartition(ctx context.Context, partition *model.Partitio
if !ok || !coll.Available() {
return fmt.Errorf("collection not exists: %d", partition.CollectionID)
}
if partition.State != pb.PartitionState_PartitionCreating {
if partition.State != pb.PartitionState_PartitionCreated {
return fmt.Errorf("partition state is not created, collection: %d, partition: %d, state: %s", partition.CollectionID, partition.PartitionID, partition.State)
}
// idempotency check here.
for _, part := range coll.Partitions {
if part.PartitionID == partition.PartitionID {
log.Ctx(ctx).Info("partition already exists, ignore the operation", zap.Int64("collection", partition.CollectionID), zap.Int64("partition", partition.PartitionID))
return nil
}
}
if err := mt.catalog.CreatePartition(ctx, coll.DBID, partition, partition.PartitionCreatedTimestamp); err != nil {
return err
}
@ -1053,11 +1071,14 @@ func (mt *MetaTable) AddPartition(ctx context.Context, partition *model.Partitio
log.Ctx(ctx).Info("add partition to meta table",
zap.Int64("collection", partition.CollectionID), zap.String("partition", partition.PartitionName),
zap.Int64("partitionid", partition.PartitionID), zap.Uint64("ts", partition.PartitionCreatedTimestamp))
mt.generalCnt += int(coll.ShardsNum) // 1 partition * shardNum
// support Dynamic load/release partitions
metrics.RootCoordNumOfPartitions.WithLabelValues().Inc()
return nil
}
func (mt *MetaTable) ChangePartitionState(ctx context.Context, collectionID UniqueID, partitionID UniqueID, state pb.PartitionState, ts Timestamp) error {
func (mt *MetaTable) DropPartition(ctx context.Context, collectionID UniqueID, partitionID UniqueID, ts Timestamp) error {
mt.ddLock.Lock()
defer mt.ddLock.Unlock()
@ -1067,32 +1088,29 @@ func (mt *MetaTable) ChangePartitionState(ctx context.Context, collectionID Uniq
}
for idx, part := range coll.Partitions {
if part.PartitionID == partitionID {
if part.State == pb.PartitionState_PartitionDropping {
// ensure idempotency here.
return nil
}
clone := part.Clone()
clone.State = state
clone.State = pb.PartitionState_PartitionDropping
ctx1 := contextutil.WithTenantID(ctx, Params.CommonCfg.ClusterName.GetValue())
if err := mt.catalog.AlterPartition(ctx1, coll.DBID, part, clone, metastore.MODIFY, ts); err != nil {
return err
}
mt.collID2Meta[collectionID].Partitions[idx] = clone
switch state {
case pb.PartitionState_PartitionCreated:
mt.generalCnt += int(coll.ShardsNum) // 1 partition * shardNum
// support Dynamic load/release partitions
metrics.RootCoordNumOfPartitions.WithLabelValues().Inc()
case pb.PartitionState_PartitionDropping:
mt.generalCnt -= int(coll.ShardsNum) // 1 partition * shardNum
metrics.RootCoordNumOfPartitions.WithLabelValues().Dec()
}
log.Ctx(ctx).Info("change partition state", zap.Int64("collection", collectionID),
zap.Int64("partition", partitionID), zap.String("state", state.String()),
log.Ctx(ctx).Info("drop partition", zap.Int64("collection", collectionID),
zap.Int64("partition", partitionID),
zap.Uint64("ts", ts))
mt.generalCnt -= int(coll.ShardsNum) // 1 partition * shardNum
metrics.RootCoordNumOfPartitions.WithLabelValues().Dec()
return nil
}
}
return fmt.Errorf("partition not exist, collection: %d, partition: %d", collectionID, partitionID)
// partition not found, so return nil to keep the operation idempotent.
return nil
}
func (mt *MetaTable) RemovePartition(ctx context.Context, dbID int64, collectionID UniqueID, partitionID UniqueID, ts Timestamp) error {

View File

@ -1456,80 +1456,6 @@ func TestMetaTable_ListAllAvailCollections(t *testing.T) {
assert.Equal(t, 0, len(db3))
}
func TestMetaTable_ChangeCollectionState(t *testing.T) {
t.Run("not exist", func(t *testing.T) {
meta := &MetaTable{}
err := meta.ChangeCollectionState(context.TODO(), 100, pb.CollectionState_CollectionCreated, 100)
assert.NoError(t, err)
})
t.Run("failed to alter collection", func(t *testing.T) {
catalog := mocks.NewRootCoordCatalog(t)
catalog.On("AlterCollection",
mock.Anything, // context.Context
mock.Anything, // *model.Collection
mock.Anything, // *model.Collection
mock.Anything, // metastore.AlterType
mock.AnythingOfType("uint64"),
mock.Anything,
).Return(errors.New("error mock AlterCollection"))
meta := &MetaTable{
catalog: catalog,
collID2Meta: map[typeutil.UniqueID]*model.Collection{
100: {Name: "test", CollectionID: 100},
},
}
err := meta.ChangeCollectionState(context.TODO(), 100, pb.CollectionState_CollectionCreated, 1000)
assert.Error(t, err)
})
t.Run("not found dbID", func(t *testing.T) {
catalog := mocks.NewRootCoordCatalog(t)
catalog.On("AlterCollection",
mock.Anything, // context.Context
mock.Anything, // *model.Collection
mock.Anything, // *model.Collection
mock.Anything, // metastore.AlterType
mock.AnythingOfType("uint64"),
mock.Anything,
).Return(nil)
meta := &MetaTable{
catalog: catalog,
dbName2Meta: map[string]*model.Database{},
collID2Meta: map[typeutil.UniqueID]*model.Collection{
100: {Name: "test", CollectionID: 100, DBID: util.DefaultDBID},
},
}
err := meta.ChangeCollectionState(context.TODO(), 100, pb.CollectionState_CollectionCreated, 1000)
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
catalog := mocks.NewRootCoordCatalog(t)
catalog.On("AlterCollection",
mock.Anything, // context.Context
mock.Anything, // *model.Collection
mock.Anything, // *model.Collection
mock.Anything, // metastore.AlterType
mock.AnythingOfType("uint64"),
mock.Anything,
).Return(nil)
meta := &MetaTable{
catalog: catalog,
dbName2Meta: map[string]*model.Database{
util.DefaultDBName: {Name: util.DefaultDBName, ID: util.DefaultDBID},
},
collID2Meta: map[typeutil.UniqueID]*model.Collection{
100: {Name: "test", CollectionID: 100, DBID: util.DefaultDBID},
},
}
err := meta.ChangeCollectionState(context.TODO(), 100, pb.CollectionState_CollectionCreated, 1000)
assert.NoError(t, err)
err = meta.ChangeCollectionState(context.TODO(), 100, pb.CollectionState_CollectionDropping, 1000)
assert.NoError(t, err)
})
}
func TestMetaTable_AddPartition(t *testing.T) {
t.Run("collection not available", func(t *testing.T) {
meta := &MetaTable{}
@ -1564,7 +1490,7 @@ func TestMetaTable_AddPartition(t *testing.T) {
100: {Name: "test", CollectionID: 100},
},
}
err := meta.AddPartition(context.TODO(), &model.Partition{CollectionID: 100, State: pb.PartitionState_PartitionCreating})
err := meta.AddPartition(context.TODO(), &model.Partition{CollectionID: 100, State: pb.PartitionState_PartitionCreated})
assert.Error(t, err)
})
@ -1582,7 +1508,7 @@ func TestMetaTable_AddPartition(t *testing.T) {
100: {Name: "test", CollectionID: 100},
},
}
err := meta.AddPartition(context.TODO(), &model.Partition{CollectionID: 100, State: pb.PartitionState_PartitionCreating})
err := meta.AddPartition(context.TODO(), &model.Partition{CollectionID: 100, State: pb.PartitionState_PartitionCreated})
assert.NoError(t, err)
})
}
@ -1870,76 +1796,6 @@ func TestMetaTable_RenameCollection(t *testing.T) {
})
}
func TestMetaTable_ChangePartitionState(t *testing.T) {
t.Run("collection not exist", func(t *testing.T) {
meta := &MetaTable{}
err := meta.ChangePartitionState(context.TODO(), 100, 500, pb.PartitionState_PartitionDropping, 1000)
assert.NoError(t, err)
})
t.Run("partition not exist", func(t *testing.T) {
meta := &MetaTable{
collID2Meta: map[typeutil.UniqueID]*model.Collection{
100: {Name: "test", CollectionID: 100},
},
}
err := meta.ChangePartitionState(context.TODO(), 100, 500, pb.PartitionState_PartitionDropping, 1000)
assert.Error(t, err)
})
t.Run("failed to alter partition", func(t *testing.T) {
catalog := mocks.NewRootCoordCatalog(t)
catalog.On("AlterPartition",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(errors.New("error mock AlterPartition"))
meta := &MetaTable{
catalog: catalog,
collID2Meta: map[typeutil.UniqueID]*model.Collection{
100: {
Name: "test", CollectionID: 100,
Partitions: []*model.Partition{
{CollectionID: 100, PartitionID: 500},
},
},
},
}
err := meta.ChangePartitionState(context.TODO(), 100, 500, pb.PartitionState_PartitionDropping, 1000)
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
catalog := mocks.NewRootCoordCatalog(t)
catalog.On("AlterPartition",
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
mock.Anything,
).Return(nil)
meta := &MetaTable{
catalog: catalog,
collID2Meta: map[typeutil.UniqueID]*model.Collection{
100: {
Name: "test", CollectionID: 100,
Partitions: []*model.Partition{
{CollectionID: 100, PartitionID: 500},
},
},
},
}
err := meta.ChangePartitionState(context.TODO(), 100, 500, pb.PartitionState_PartitionCreated, 1000)
assert.NoError(t, err)
err = meta.ChangePartitionState(context.TODO(), 100, 500, pb.PartitionState_PartitionDropping, 1000)
assert.NoError(t, err)
})
}
func TestMetaTable_CreateDatabase(t *testing.T) {
db := model.NewDatabase(1, "exist", pb.DatabaseState_DatabaseCreated, nil)
t.Run("database already exist", func(t *testing.T) {

View File

@ -31,11 +31,13 @@ import (
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/mocks/rootcoord/mock_tombstone"
"github.com/milvus-io/milvus/internal/tso"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/proxyutil"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
@ -407,9 +409,14 @@ func newMockProxy() *mockProxy {
}
func newTestCore(opts ...Opt) *Core {
tombstoneSweeper := mock_tombstone.NewMockTombstoneSweeper(common.NewEmptyMockT())
tombstoneSweeper.EXPECT().AddTombstone(mock.Anything).Return()
tombstoneSweeper.EXPECT().Close().Return()
c := &Core{
metricsRequest: metricsinfo.NewMetricsRequest(),
session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: TestRootCoordID}},
metricsRequest: metricsinfo.NewMetricsRequest(),
session: &sessionutil.Session{SessionRaw: sessionutil.SessionRaw{ServerID: TestRootCoordID}},
tombstoneSweeper: tombstoneSweeper,
}
executor := newMockStepExecutor()
executor.AddStepsFunc = func(s *stepStack) {
@ -737,6 +744,11 @@ func withValidMixCoord() Opt {
Status: merr.Success(),
}, nil,
)
mixc.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything).Return(
&datapb.DropVirtualChannelResponse{
Status: merr.Success(),
}, nil,
)
mixc.EXPECT().Flush(mock.Anything, mock.Anything).Return(
&datapb.FlushResponse{
@ -750,7 +762,7 @@ func withValidMixCoord() Opt {
mixc.EXPECT().DropIndex(mock.Anything, mock.Anything).Return(
merr.Success(), nil,
)
mixc.EXPECT().NotifyDropPartition(mock.Anything, mock.Anything, mock.Anything).Return(nil)
return withMixCoord(mixc)
}
@ -908,6 +920,23 @@ type mockBroker struct {
GCConfirmFunc func(ctx context.Context, collectionID, partitionID UniqueID) bool
}
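// newValidMockBroker returns a mockBroker whose release and drop-index hooks all succeed.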
func newValidMockBroker() *mockBroker {
broker := newMockBroker()
broker.DropCollectionIndexFunc = func(ctx context.Context, collID UniqueID, partIDs []UniqueID) error {
return nil
}
broker.ReleaseCollectionFunc = func(ctx context.Context, collectionID UniqueID) error {
return nil
}
broker.ReleasePartitionsFunc = func(ctx context.Context, collectionID UniqueID, partitionIDs ...UniqueID) error {
return nil
}
return broker
}
func newMockBroker() *mockBroker {
return &mockBroker{}
}
@ -954,12 +983,6 @@ func withBroker(b Broker) Opt {
}
}
func withGarbageCollector(gc GarbageCollector) Opt {
return func(c *Core) {
c.garbageCollector = gc
}
}
func newMockFailStream() *msgstream.WastedMockMsgStream {
stream := msgstream.NewWastedMockMsgStream()
stream.BroadcastFunc = func(pack *msgstream.MsgPack) error {

View File

@ -1,292 +0,0 @@
// Code generated by mockery v2.53.3. DO NOT EDIT.
package mockrootcoord
import (
context "context"
model "github.com/milvus-io/milvus/internal/metastore/model"
mock "github.com/stretchr/testify/mock"
)
// GarbageCollector is an autogenerated mock type for the GarbageCollector type
type GarbageCollector struct {
mock.Mock
}
type GarbageCollector_Expecter struct {
mock *mock.Mock
}
func (_m *GarbageCollector) EXPECT() *GarbageCollector_Expecter {
return &GarbageCollector_Expecter{mock: &_m.Mock}
}
// GcCollectionData provides a mock function with given fields: ctx, coll
func (_m *GarbageCollector) GcCollectionData(ctx context.Context, coll *model.Collection) (uint64, error) {
ret := _m.Called(ctx, coll)
if len(ret) == 0 {
panic("no return value specified for GcCollectionData")
}
var r0 uint64
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *model.Collection) (uint64, error)); ok {
return rf(ctx, coll)
}
if rf, ok := ret.Get(0).(func(context.Context, *model.Collection) uint64); ok {
r0 = rf(ctx, coll)
} else {
r0 = ret.Get(0).(uint64)
}
if rf, ok := ret.Get(1).(func(context.Context, *model.Collection) error); ok {
r1 = rf(ctx, coll)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GarbageCollector_GcCollectionData_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GcCollectionData'
type GarbageCollector_GcCollectionData_Call struct {
*mock.Call
}
// GcCollectionData is a helper method to define mock.On call
// - ctx context.Context
// - coll *model.Collection
func (_e *GarbageCollector_Expecter) GcCollectionData(ctx interface{}, coll interface{}) *GarbageCollector_GcCollectionData_Call {
return &GarbageCollector_GcCollectionData_Call{Call: _e.mock.On("GcCollectionData", ctx, coll)}
}
func (_c *GarbageCollector_GcCollectionData_Call) Run(run func(ctx context.Context, coll *model.Collection)) *GarbageCollector_GcCollectionData_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*model.Collection))
})
return _c
}
func (_c *GarbageCollector_GcCollectionData_Call) Return(ddlTs uint64, err error) *GarbageCollector_GcCollectionData_Call {
_c.Call.Return(ddlTs, err)
return _c
}
func (_c *GarbageCollector_GcCollectionData_Call) RunAndReturn(run func(context.Context, *model.Collection) (uint64, error)) *GarbageCollector_GcCollectionData_Call {
_c.Call.Return(run)
return _c
}
// GcPartitionData provides a mock function with given fields: ctx, pChannels, vchannels, partition
func (_m *GarbageCollector) GcPartitionData(ctx context.Context, pChannels []string, vchannels []string, partition *model.Partition) (uint64, error) {
ret := _m.Called(ctx, pChannels, vchannels, partition)
if len(ret) == 0 {
panic("no return value specified for GcPartitionData")
}
var r0 uint64
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, []string, []string, *model.Partition) (uint64, error)); ok {
return rf(ctx, pChannels, vchannels, partition)
}
if rf, ok := ret.Get(0).(func(context.Context, []string, []string, *model.Partition) uint64); ok {
r0 = rf(ctx, pChannels, vchannels, partition)
} else {
r0 = ret.Get(0).(uint64)
}
if rf, ok := ret.Get(1).(func(context.Context, []string, []string, *model.Partition) error); ok {
r1 = rf(ctx, pChannels, vchannels, partition)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GarbageCollector_GcPartitionData_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GcPartitionData'
type GarbageCollector_GcPartitionData_Call struct {
*mock.Call
}
// GcPartitionData is a helper method to define mock.On call
// - ctx context.Context
// - pChannels []string
// - vchannels []string
// - partition *model.Partition
func (_e *GarbageCollector_Expecter) GcPartitionData(ctx interface{}, pChannels interface{}, vchannels interface{}, partition interface{}) *GarbageCollector_GcPartitionData_Call {
return &GarbageCollector_GcPartitionData_Call{Call: _e.mock.On("GcPartitionData", ctx, pChannels, vchannels, partition)}
}
func (_c *GarbageCollector_GcPartitionData_Call) Run(run func(ctx context.Context, pChannels []string, vchannels []string, partition *model.Partition)) *GarbageCollector_GcPartitionData_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].([]string), args[2].([]string), args[3].(*model.Partition))
})
return _c
}
func (_c *GarbageCollector_GcPartitionData_Call) Return(ddlTs uint64, err error) *GarbageCollector_GcPartitionData_Call {
_c.Call.Return(ddlTs, err)
return _c
}
func (_c *GarbageCollector_GcPartitionData_Call) RunAndReturn(run func(context.Context, []string, []string, *model.Partition) (uint64, error)) *GarbageCollector_GcPartitionData_Call {
_c.Call.Return(run)
return _c
}
// ReDropCollection provides a mock function with given fields: collMeta, ts
func (_m *GarbageCollector) ReDropCollection(collMeta *model.Collection, ts uint64) {
_m.Called(collMeta, ts)
}
// GarbageCollector_ReDropCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReDropCollection'
type GarbageCollector_ReDropCollection_Call struct {
*mock.Call
}
// ReDropCollection is a helper method to define mock.On call
// - collMeta *model.Collection
// - ts uint64
func (_e *GarbageCollector_Expecter) ReDropCollection(collMeta interface{}, ts interface{}) *GarbageCollector_ReDropCollection_Call {
return &GarbageCollector_ReDropCollection_Call{Call: _e.mock.On("ReDropCollection", collMeta, ts)}
}
func (_c *GarbageCollector_ReDropCollection_Call) Run(run func(collMeta *model.Collection, ts uint64)) *GarbageCollector_ReDropCollection_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(*model.Collection), args[1].(uint64))
})
return _c
}
func (_c *GarbageCollector_ReDropCollection_Call) Return() *GarbageCollector_ReDropCollection_Call {
_c.Call.Return()
return _c
}
func (_c *GarbageCollector_ReDropCollection_Call) RunAndReturn(run func(*model.Collection, uint64)) *GarbageCollector_ReDropCollection_Call {
_c.Run(run)
return _c
}
// ReDropPartition provides a mock function with given fields: dbID, pChannels, vchannels, partition, ts
func (_m *GarbageCollector) ReDropPartition(dbID int64, pChannels []string, vchannels []string, partition *model.Partition, ts uint64) {
_m.Called(dbID, pChannels, vchannels, partition, ts)
}
// GarbageCollector_ReDropPartition_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReDropPartition'
type GarbageCollector_ReDropPartition_Call struct {
*mock.Call
}
// ReDropPartition is a helper method to define mock.On call
// - dbID int64
// - pChannels []string
// - vchannels []string
// - partition *model.Partition
// - ts uint64
func (_e *GarbageCollector_Expecter) ReDropPartition(dbID interface{}, pChannels interface{}, vchannels interface{}, partition interface{}, ts interface{}) *GarbageCollector_ReDropPartition_Call {
return &GarbageCollector_ReDropPartition_Call{Call: _e.mock.On("ReDropPartition", dbID, pChannels, vchannels, partition, ts)}
}
func (_c *GarbageCollector_ReDropPartition_Call) Run(run func(dbID int64, pChannels []string, vchannels []string, partition *model.Partition, ts uint64)) *GarbageCollector_ReDropPartition_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].([]string), args[2].([]string), args[3].(*model.Partition), args[4].(uint64))
})
return _c
}
func (_c *GarbageCollector_ReDropPartition_Call) Return() *GarbageCollector_ReDropPartition_Call {
_c.Call.Return()
return _c
}
func (_c *GarbageCollector_ReDropPartition_Call) RunAndReturn(run func(int64, []string, []string, *model.Partition, uint64)) *GarbageCollector_ReDropPartition_Call {
_c.Run(run)
return _c
}
// RemoveCreatingCollection provides a mock function with given fields: collMeta
func (_m *GarbageCollector) RemoveCreatingCollection(collMeta *model.Collection) {
_m.Called(collMeta)
}
// GarbageCollector_RemoveCreatingCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveCreatingCollection'
type GarbageCollector_RemoveCreatingCollection_Call struct {
*mock.Call
}
// RemoveCreatingCollection is a helper method to define mock.On call
// - collMeta *model.Collection
func (_e *GarbageCollector_Expecter) RemoveCreatingCollection(collMeta interface{}) *GarbageCollector_RemoveCreatingCollection_Call {
return &GarbageCollector_RemoveCreatingCollection_Call{Call: _e.mock.On("RemoveCreatingCollection", collMeta)}
}
func (_c *GarbageCollector_RemoveCreatingCollection_Call) Run(run func(collMeta *model.Collection)) *GarbageCollector_RemoveCreatingCollection_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(*model.Collection))
})
return _c
}
func (_c *GarbageCollector_RemoveCreatingCollection_Call) Return() *GarbageCollector_RemoveCreatingCollection_Call {
_c.Call.Return()
return _c
}
func (_c *GarbageCollector_RemoveCreatingCollection_Call) RunAndReturn(run func(*model.Collection)) *GarbageCollector_RemoveCreatingCollection_Call {
_c.Run(run)
return _c
}
// RemoveCreatingPartition provides a mock function with given fields: dbID, partition, ts
func (_m *GarbageCollector) RemoveCreatingPartition(dbID int64, partition *model.Partition, ts uint64) {
_m.Called(dbID, partition, ts)
}
// GarbageCollector_RemoveCreatingPartition_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveCreatingPartition'
type GarbageCollector_RemoveCreatingPartition_Call struct {
*mock.Call
}
// RemoveCreatingPartition is a helper method to define mock.On call
// - dbID int64
// - partition *model.Partition
// - ts uint64
func (_e *GarbageCollector_Expecter) RemoveCreatingPartition(dbID interface{}, partition interface{}, ts interface{}) *GarbageCollector_RemoveCreatingPartition_Call {
return &GarbageCollector_RemoveCreatingPartition_Call{Call: _e.mock.On("RemoveCreatingPartition", dbID, partition, ts)}
}
func (_c *GarbageCollector_RemoveCreatingPartition_Call) Run(run func(dbID int64, partition *model.Partition, ts uint64)) *GarbageCollector_RemoveCreatingPartition_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].(*model.Partition), args[2].(uint64))
})
return _c
}
func (_c *GarbageCollector_RemoveCreatingPartition_Call) Return() *GarbageCollector_RemoveCreatingPartition_Call {
_c.Call.Return()
return _c
}
func (_c *GarbageCollector_RemoveCreatingPartition_Call) RunAndReturn(run func(int64, *model.Partition, uint64)) *GarbageCollector_RemoveCreatingPartition_Call {
_c.Run(run)
return _c
}
// NewGarbageCollector creates a new instance of GarbageCollector. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewGarbageCollector(t interface {
mock.TestingT
Cleanup(func())
}) *GarbageCollector {
mock := &GarbageCollector{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -5,9 +5,7 @@ package mockrootcoord
import (
context "context"
etcdpb "github.com/milvus-io/milvus/pkg/v2/proto/etcdpb"
internalpb "github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
message "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
messagespb "github.com/milvus-io/milvus/pkg/v2/proto/messagespb"
@ -379,105 +377,6 @@ func (_c *IMetaTable_BackupRBAC_Call) RunAndReturn(run func(context.Context, str
return _c
}
// ChangeCollectionState provides a mock function with given fields: ctx, collectionID, state, ts
func (_m *IMetaTable) ChangeCollectionState(ctx context.Context, collectionID int64, state etcdpb.CollectionState, ts uint64) error {
ret := _m.Called(ctx, collectionID, state, ts)
if len(ret) == 0 {
panic("no return value specified for ChangeCollectionState")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, int64, etcdpb.CollectionState, uint64) error); ok {
r0 = rf(ctx, collectionID, state, ts)
} else {
r0 = ret.Error(0)
}
return r0
}
// IMetaTable_ChangeCollectionState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ChangeCollectionState'
type IMetaTable_ChangeCollectionState_Call struct {
*mock.Call
}
// ChangeCollectionState is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - state etcdpb.CollectionState
// - ts uint64
func (_e *IMetaTable_Expecter) ChangeCollectionState(ctx interface{}, collectionID interface{}, state interface{}, ts interface{}) *IMetaTable_ChangeCollectionState_Call {
return &IMetaTable_ChangeCollectionState_Call{Call: _e.mock.On("ChangeCollectionState", ctx, collectionID, state, ts)}
}
func (_c *IMetaTable_ChangeCollectionState_Call) Run(run func(ctx context.Context, collectionID int64, state etcdpb.CollectionState, ts uint64)) *IMetaTable_ChangeCollectionState_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64), args[2].(etcdpb.CollectionState), args[3].(uint64))
})
return _c
}
func (_c *IMetaTable_ChangeCollectionState_Call) Return(_a0 error) *IMetaTable_ChangeCollectionState_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *IMetaTable_ChangeCollectionState_Call) RunAndReturn(run func(context.Context, int64, etcdpb.CollectionState, uint64) error) *IMetaTable_ChangeCollectionState_Call {
_c.Call.Return(run)
return _c
}
// ChangePartitionState provides a mock function with given fields: ctx, collectionID, partitionID, state, ts
func (_m *IMetaTable) ChangePartitionState(ctx context.Context, collectionID int64, partitionID int64, state etcdpb.PartitionState, ts uint64) error {
ret := _m.Called(ctx, collectionID, partitionID, state, ts)
if len(ret) == 0 {
panic("no return value specified for ChangePartitionState")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, int64, int64, etcdpb.PartitionState, uint64) error); ok {
r0 = rf(ctx, collectionID, partitionID, state, ts)
} else {
r0 = ret.Error(0)
}
return r0
}
// IMetaTable_ChangePartitionState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ChangePartitionState'
type IMetaTable_ChangePartitionState_Call struct {
*mock.Call
}
// ChangePartitionState is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - partitionID int64
// - state etcdpb.PartitionState
// - ts uint64
func (_e *IMetaTable_Expecter) ChangePartitionState(ctx interface{}, collectionID interface{}, partitionID interface{}, state interface{}, ts interface{}) *IMetaTable_ChangePartitionState_Call {
return &IMetaTable_ChangePartitionState_Call{Call: _e.mock.On("ChangePartitionState", ctx, collectionID, partitionID, state, ts)}
}
func (_c *IMetaTable_ChangePartitionState_Call) Run(run func(ctx context.Context, collectionID int64, partitionID int64, state etcdpb.PartitionState, ts uint64)) *IMetaTable_ChangePartitionState_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64), args[2].(int64), args[3].(etcdpb.PartitionState), args[4].(uint64))
})
return _c
}
func (_c *IMetaTable_ChangePartitionState_Call) Return(_a0 error) *IMetaTable_ChangePartitionState_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *IMetaTable_ChangePartitionState_Call) RunAndReturn(run func(context.Context, int64, int64, etcdpb.PartitionState, uint64) error) *IMetaTable_ChangePartitionState_Call {
_c.Call.Return(run)
return _c
}
// CheckIfAddCredential provides a mock function with given fields: ctx, req
func (_m *IMetaTable) CheckIfAddCredential(ctx context.Context, req *internalpb.CredentialInfo) error {
ret := _m.Called(ctx, req)
@ -1484,6 +1383,54 @@ func (_c *IMetaTable_DropAlias_Call) RunAndReturn(run func(context.Context, mess
return _c
}
// DropCollection provides a mock function with given fields: ctx, collectionID, ts
func (_m *IMetaTable) DropCollection(ctx context.Context, collectionID int64, ts uint64) error {
ret := _m.Called(ctx, collectionID, ts)
if len(ret) == 0 {
panic("no return value specified for DropCollection")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, int64, uint64) error); ok {
r0 = rf(ctx, collectionID, ts)
} else {
r0 = ret.Error(0)
}
return r0
}
// IMetaTable_DropCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropCollection'
type IMetaTable_DropCollection_Call struct {
*mock.Call
}
// DropCollection is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - ts uint64
func (_e *IMetaTable_Expecter) DropCollection(ctx interface{}, collectionID interface{}, ts interface{}) *IMetaTable_DropCollection_Call {
return &IMetaTable_DropCollection_Call{Call: _e.mock.On("DropCollection", ctx, collectionID, ts)}
}
func (_c *IMetaTable_DropCollection_Call) Run(run func(ctx context.Context, collectionID int64, ts uint64)) *IMetaTable_DropCollection_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64), args[2].(uint64))
})
return _c
}
func (_c *IMetaTable_DropCollection_Call) Return(_a0 error) *IMetaTable_DropCollection_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *IMetaTable_DropCollection_Call) RunAndReturn(run func(context.Context, int64, uint64) error) *IMetaTable_DropCollection_Call {
_c.Call.Return(run)
return _c
}
// DropDatabase provides a mock function with given fields: ctx, dbName, ts
func (_m *IMetaTable) DropDatabase(ctx context.Context, dbName string, ts uint64) error {
ret := _m.Called(ctx, dbName, ts)
@ -1580,6 +1527,55 @@ func (_c *IMetaTable_DropGrant_Call) RunAndReturn(run func(context.Context, stri
return _c
}
// DropPartition provides a mock function with given fields: ctx, collectionID, partitionID, ts
func (_m *IMetaTable) DropPartition(ctx context.Context, collectionID int64, partitionID int64, ts uint64) error {
ret := _m.Called(ctx, collectionID, partitionID, ts)
if len(ret) == 0 {
panic("no return value specified for DropPartition")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, int64, int64, uint64) error); ok {
r0 = rf(ctx, collectionID, partitionID, ts)
} else {
r0 = ret.Error(0)
}
return r0
}
// IMetaTable_DropPartition_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropPartition'
type IMetaTable_DropPartition_Call struct {
*mock.Call
}
// DropPartition is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - partitionID int64
// - ts uint64
func (_e *IMetaTable_Expecter) DropPartition(ctx interface{}, collectionID interface{}, partitionID interface{}, ts interface{}) *IMetaTable_DropPartition_Call {
return &IMetaTable_DropPartition_Call{Call: _e.mock.On("DropPartition", ctx, collectionID, partitionID, ts)}
}
func (_c *IMetaTable_DropPartition_Call) Run(run func(ctx context.Context, collectionID int64, partitionID int64, ts uint64)) *IMetaTable_DropPartition_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64), args[2].(int64), args[3].(uint64))
})
return _c
}
func (_c *IMetaTable_DropPartition_Call) Return(_a0 error) *IMetaTable_DropPartition_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *IMetaTable_DropPartition_Call) RunAndReturn(run func(context.Context, int64, int64, uint64) error) *IMetaTable_DropPartition_Call {
_c.Call.Return(run)
return _c
}
// DropPrivilegeGroup provides a mock function with given fields: ctx, groupName
func (_m *IMetaTable) DropPrivilegeGroup(ctx context.Context, groupName string) error {
ret := _m.Called(ctx, groupName)

View File

@ -43,6 +43,7 @@ import (
"github.com/milvus-io/milvus/internal/metastore"
kvmetastore "github.com/milvus-io/milvus/internal/metastore/kv/rootcoord"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/rootcoord/tombstone"
streamingcoord "github.com/milvus-io/milvus/internal/streamingcoord/server"
tso2 "github.com/milvus-io/milvus/internal/tso"
"github.com/milvus-io/milvus/internal/types"
@ -99,7 +100,6 @@ type Core struct {
scheduler IScheduler
broker Broker
ddlTsLockManager DdlTsLockManager
garbageCollector GarbageCollector
stepExecutor StepExecutor
metaKVCreator metaKVCreator
@ -129,6 +129,8 @@ type Core struct {
activateFunc func() error
metricsRequest *metricsinfo.MetricsRequest
tombstoneSweeper tombstone.TombstoneSweeper
}
// --------------------- function --------------------------
@ -439,7 +441,6 @@ func (c *Core) initInternal() error {
c.broker = newServerBroker(c)
c.ddlTsLockManager = newDdlTsLockManager(c.tsoAllocator)
c.garbageCollector = newBgGarbageCollector(c)
c.stepExecutor = newBgStepExecutor(c.ctx)
c.proxyWatcher = proxyutil.NewProxyWatcher(
@ -490,7 +491,6 @@ func (c *Core) Init() error {
RegisterDDLCallbacks(c)
})
log.Info("RootCoord init successfully")
return initError
}
@ -607,33 +607,23 @@ func (c *Core) restore(ctx context.Context) error {
return err
}
c.tombstoneSweeper = tombstone.NewTombstoneSweeper()
for _, db := range dbs {
colls, err := c.meta.ListCollections(ctx, db.Name, typeutil.MaxTimestamp, false)
if err != nil {
return err
}
// restore the tombstones into the tombstone sweeper.
for _, coll := range colls {
ts, err := c.tsoAllocator.GenerateTSO(1)
if err != nil {
return err
// CollectionCreating is a deprecated status,
// we cannot guarantee that the coordinator handles it correctly, so just treat it as a tombstone.
if coll.State == pb.CollectionState_CollectionDropping || coll.State == pb.CollectionState_CollectionCreating {
c.tombstoneSweeper.AddTombstone(newCollectionTombstone(c.meta, c.broker, coll.CollectionID))
continue
}
if coll.Available() {
for _, part := range coll.Partitions {
switch part.State {
case pb.PartitionState_PartitionDropping:
go c.garbageCollector.ReDropPartition(coll.DBID, coll.PhysicalChannelNames, coll.VirtualChannelNames, part.Clone(), ts)
case pb.PartitionState_PartitionCreating:
go c.garbageCollector.RemoveCreatingPartition(coll.DBID, part.Clone(), ts)
default:
}
}
} else {
switch coll.State {
case pb.CollectionState_CollectionDropping:
go c.garbageCollector.ReDropCollection(coll.Clone(), ts)
case pb.CollectionState_CollectionCreating:
go c.garbageCollector.RemoveCreatingCollection(coll.Clone())
default:
for _, part := range coll.Partitions {
if part.State == pb.PartitionState_PartitionDropping || part.State == pb.PartitionState_PartitionCreating {
c.tombstoneSweeper.AddTombstone(newPartitionTombstone(c.meta, c.broker, coll.CollectionID, part.PartitionID))
}
}
}
@ -733,6 +723,9 @@ func (c *Core) GracefulStop() {
// Stop stops rootCoord.
func (c *Core) Stop() error {
c.UpdateStateCode(commonpb.StateCode_Abnormal)
if c.tombstoneSweeper != nil {
c.tombstoneSweeper.Close()
}
c.stopExecutor()
c.stopScheduler()
@ -903,49 +896,28 @@ func (c *Core) CreateCollection(ctx context.Context, in *milvuspb.CreateCollecti
if err := merr.CheckHealthy(c.GetStateCode()); err != nil {
return merr.Status(err), nil
}
metrics.RootCoordDDLReqCounter.WithLabelValues("CreateCollection", metrics.TotalLabel).Inc()
tr := timerecord.NewTimeRecorder("CreateCollection")
log.Ctx(ctx).Info("received request to create collection",
logger := log.Ctx(ctx).With(zap.String("role", typeutil.RootCoordRole),
zap.String("dbName", in.GetDbName()),
zap.String("name", in.GetCollectionName()),
zap.String("role", typeutil.RootCoordRole))
t := &createCollectionTask{
baseTask: newBaseTask(ctx, c),
Req: in,
}
if err := c.scheduler.AddTask(t); err != nil {
log.Ctx(ctx).Info("failed to enqueue request to create collection",
zap.String("role", typeutil.RootCoordRole),
zap.Error(err),
zap.String("name", in.GetCollectionName()))
metrics.RootCoordDDLReqCounter.WithLabelValues("CreateCollection", metrics.FailLabel).Inc()
return merr.Status(err), nil
}
if err := t.WaitToFinish(); err != nil {
log.Ctx(ctx).Info("failed to create collection",
zap.String("role", typeutil.RootCoordRole),
zap.Error(err),
zap.String("name", in.GetCollectionName()),
zap.Uint64("ts", t.GetTs()))
zap.String("collectionName", in.GetCollectionName()))
logger.Info("received request to create collection")
if err := c.broadcastCreateCollectionV1(ctx, in); err != nil {
if errors.Is(err, errIgnoredCreateCollection) {
logger.Info("create existed collection with same schema, ignore it")
metrics.RootCoordDDLReqCounter.WithLabelValues("CreateCollection", metrics.SuccessLabel).Inc()
return merr.Success(), nil
}
logger.Info("failed to create collection", zap.Error(err))
metrics.RootCoordDDLReqCounter.WithLabelValues("CreateCollection", metrics.FailLabel).Inc()
return merr.Status(err), nil
}
metrics.RootCoordDDLReqCounter.WithLabelValues("CreateCollection", metrics.SuccessLabel).Inc()
metrics.RootCoordDDLReqLatency.WithLabelValues("CreateCollection").Observe(float64(tr.ElapseSpan().Milliseconds()))
metrics.RootCoordDDLReqLatencyInQueue.WithLabelValues("CreateCollection").Observe(float64(t.queueDur.Milliseconds()))
log.Ctx(ctx).Info("done to create collection",
zap.String("role", typeutil.RootCoordRole),
zap.String("name", in.GetCollectionName()),
zap.Uint64("ts", t.GetTs()))
logger.Info("done to create collection")
return merr.Success(), nil
}
@ -1005,46 +977,28 @@ func (c *Core) DropCollection(ctx context.Context, in *milvuspb.DropCollectionRe
if err := merr.CheckHealthy(c.GetStateCode()); err != nil {
return merr.Status(err), nil
}
metrics.RootCoordDDLReqCounter.WithLabelValues("DropCollection", metrics.TotalLabel).Inc()
tr := timerecord.NewTimeRecorder("DropCollection")
log.Ctx(ctx).Info("received request to drop collection",
zap.String("role", typeutil.RootCoordRole),
logger := log.Ctx(ctx).With(zap.String("role", typeutil.RootCoordRole),
zap.String("dbName", in.GetDbName()),
zap.String("name", in.GetCollectionName()))
logger.Info("received request to drop collection")
t := &dropCollectionTask{
baseTask: newBaseTask(ctx, c),
Req: in,
}
if err := c.scheduler.AddTask(t); err != nil {
log.Ctx(ctx).Info("failed to enqueue request to drop collection", zap.String("role", typeutil.RootCoordRole),
zap.Error(err),
zap.String("name", in.GetCollectionName()))
metrics.RootCoordDDLReqCounter.WithLabelValues("DropCollection", metrics.FailLabel).Inc()
return merr.Status(err), nil
}
if err := t.WaitToFinish(); err != nil {
log.Ctx(ctx).Info("failed to drop collection", zap.String("role", typeutil.RootCoordRole),
zap.Error(err),
zap.String("name", in.GetCollectionName()),
zap.Uint64("ts", t.GetTs()))
if err := c.broadcastDropCollectionV1(ctx, in); err != nil {
if errors.Is(err, errIgnoredDropCollection) {
logger.Info("drop collection that not found, ignore it")
metrics.RootCoordDDLReqCounter.WithLabelValues("DropCollection", metrics.SuccessLabel).Inc()
return merr.Success(), nil
}
logger.Info("failed to drop collection", zap.Error(err))
metrics.RootCoordDDLReqCounter.WithLabelValues("DropCollection", metrics.FailLabel).Inc()
return merr.Status(err), nil
}
metrics.RootCoordDDLReqCounter.WithLabelValues("DropCollection", metrics.SuccessLabel).Inc()
metrics.RootCoordDDLReqLatency.WithLabelValues("DropCollection").Observe(float64(tr.ElapseSpan().Milliseconds()))
metrics.RootCoordDDLReqLatencyInQueue.WithLabelValues("DropCollection").Observe(float64(t.queueDur.Milliseconds()))
log.Ctx(ctx).Info("done to drop collection", zap.String("role", typeutil.RootCoordRole),
zap.String("name", in.GetCollectionName()),
zap.Uint64("ts", t.GetTs()))
logger.Info("done to drop collection")
return merr.Success(), nil
}
@ -1485,52 +1439,29 @@ func (c *Core) CreatePartition(ctx context.Context, in *milvuspb.CreatePartition
if err := merr.CheckHealthy(c.GetStateCode()); err != nil {
return merr.Status(err), nil
}
metrics.RootCoordDDLReqCounter.WithLabelValues("CreatePartition", metrics.TotalLabel).Inc()
tr := timerecord.NewTimeRecorder("CreatePartition")
log.Ctx(ctx).Info("received request to create partition",
zap.String("role", typeutil.RootCoordRole),
zap.String("collection", in.GetCollectionName()),
zap.String("partition", in.GetPartitionName()))
t := &createPartitionTask{
baseTask: newBaseTask(ctx, c),
Req: in,
}
if err := c.scheduler.AddTask(t); err != nil {
log.Ctx(ctx).Info("failed to enqueue request to create partition",
zap.String("role", typeutil.RootCoordRole),
zap.Error(err),
zap.String("collection", in.GetCollectionName()),
zap.String("partition", in.GetPartitionName()))
metrics.RootCoordDDLReqCounter.WithLabelValues("CreatePartition", metrics.FailLabel).Inc()
return merr.Status(err), nil
}
if err := t.WaitToFinish(); err != nil {
log.Ctx(ctx).Info("failed to create partition",
zap.String("role", typeutil.RootCoordRole),
zap.Error(err),
zap.String("collection", in.GetCollectionName()),
zap.String("partition", in.GetPartitionName()),
zap.Uint64("ts", t.GetTs()))
logger := log.Ctx(ctx).With(zap.String("role", typeutil.RootCoordRole),
zap.String("dbName", in.GetDbName()),
zap.String("collectionName", in.GetCollectionName()),
zap.String("partitionName", in.GetPartitionName()))
logger.Info("received request to create partition")
if err := c.broadcastCreatePartition(ctx, in); err != nil {
if errors.Is(err, errIgnoerdCreatePartition) {
logger.Info("create partition that already exists, ignore it")
metrics.RootCoordDDLReqCounter.WithLabelValues("CreatePartition", metrics.SuccessLabel).Inc()
return merr.Success(), nil
}
logger.Info("failed to create partition", zap.Error(err))
metrics.RootCoordDDLReqCounter.WithLabelValues("CreatePartition", metrics.FailLabel).Inc()
return merr.Status(err), nil
}
metrics.RootCoordDDLReqCounter.WithLabelValues("CreatePartition", metrics.SuccessLabel).Inc()
metrics.RootCoordDDLReqLatency.WithLabelValues("CreatePartition").Observe(float64(tr.ElapseSpan().Milliseconds()))
metrics.RootCoordDDLReqLatencyInQueue.WithLabelValues("CreatePartition").Observe(float64(t.queueDur.Milliseconds()))
log.Ctx(ctx).Info("done to create partition",
zap.String("role", typeutil.RootCoordRole),
zap.String("collection", in.GetCollectionName()),
zap.String("partition", in.GetPartitionName()),
zap.Uint64("ts", t.GetTs()))
logger.Info("done to create partition")
return merr.Success(), nil
}
@ -1542,48 +1473,26 @@ func (c *Core) DropPartition(ctx context.Context, in *milvuspb.DropPartitionRequ
metrics.RootCoordDDLReqCounter.WithLabelValues("DropPartition", metrics.TotalLabel).Inc()
tr := timerecord.NewTimeRecorder("DropPartition")
log.Ctx(ctx).Info("received request to drop partition",
zap.String("role", typeutil.RootCoordRole),
logger := log.Ctx(ctx).With(zap.String("role", typeutil.RootCoordRole),
zap.String("dbName", in.GetDbName()),
zap.String("collection", in.GetCollectionName()),
zap.String("partition", in.GetPartitionName()))
logger.Info("received request to drop partition")
t := &dropPartitionTask{
baseTask: newBaseTask(ctx, c),
Req: in,
}
if err := c.scheduler.AddTask(t); err != nil {
log.Ctx(ctx).Info("failed to enqueue request to drop partition",
zap.String("role", typeutil.RootCoordRole),
zap.Error(err),
zap.String("collection", in.GetCollectionName()),
zap.String("partition", in.GetPartitionName()))
metrics.RootCoordDDLReqCounter.WithLabelValues("DropPartition", metrics.FailLabel).Inc()
return merr.Status(err), nil
}
if err := t.WaitToFinish(); err != nil {
log.Ctx(ctx).Info("failed to drop partition",
zap.String("role", typeutil.RootCoordRole),
zap.Error(err),
zap.String("collection", in.GetCollectionName()),
zap.String("partition", in.GetPartitionName()),
zap.Uint64("ts", t.GetTs()))
if err := c.broadcastDropPartition(ctx, in); err != nil {
if errors.Is(err, errIgnoredDropPartition) {
logger.Info("drop partition that not found, ignore it")
metrics.RootCoordDDLReqCounter.WithLabelValues("DropPartition", metrics.SuccessLabel).Inc()
return merr.Success(), nil
}
logger.Warn("failed to drop partition", zap.Error(err))
metrics.RootCoordDDLReqCounter.WithLabelValues("DropPartition", metrics.FailLabel).Inc()
return merr.Status(err), nil
}
metrics.RootCoordDDLReqCounter.WithLabelValues("DropPartition", metrics.SuccessLabel).Inc()
metrics.RootCoordDDLReqLatency.WithLabelValues("DropPartition").Observe(float64(tr.ElapseSpan().Milliseconds()))
metrics.RootCoordDDLReqLatencyInQueue.WithLabelValues("DropPartition").Observe(float64(t.queueDur.Milliseconds()))
log.Ctx(ctx).Info("done to drop partition",
zap.String("role", typeutil.RootCoordRole),
zap.String("collection", in.GetCollectionName()),
zap.String("partition", in.GetPartitionName()),
zap.Uint64("ts", t.GetTs()))
logger.Info("done to drop partition")
return merr.Success(), nil
}

View File

@ -18,6 +18,7 @@ package rootcoord
import (
"context"
"fmt"
"math/rand"
"os"
"testing"
@ -26,17 +27,26 @@ import (
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/coordinator/snmanager"
"github.com/milvus-io/milvus/internal/distributed/streaming"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/metastore/kv/rootcoord"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/mocks/distributed/mock_streaming"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_balancer"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_broadcaster"
mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/balance"
"github.com/milvus-io/milvus/internal/streamingcoord/server/balancer/channel"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/broadcast"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/proto/etcdpb"
@ -45,12 +55,13 @@ import (
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/walimplstest"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/retry"
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -62,10 +73,41 @@ func TestMain(m *testing.M) {
os.Exit(code)
}
func initStreamingSystem() {
t := common.NewEmptyMockT()
func initStreamingSystemAndCore(t *testing.T) *Core {
kv, _ := kvfactory.GetEtcdAndPath()
path := funcutil.RandomString(10)
catalogKV := etcdkv.NewEtcdKV(kv, path)
ss, err := rootcoord.NewSuffixSnapshot(catalogKV, rootcoord.SnapshotsSep, path, rootcoord.SnapshotPrefix)
require.NoError(t, err)
testDB := newNameDb()
collID2Meta := make(map[typeutil.UniqueID]*model.Collection)
core := newTestCore(withHealthyCode(),
withMeta(&MetaTable{
catalog: rootcoord.NewCatalog(catalogKV, ss),
names: testDB,
aliases: newNameDb(),
dbName2Meta: make(map[string]*model.Database),
collID2Meta: collID2Meta,
}),
withValidMixCoord(),
withValidProxyManager(),
withValidIDAllocator(),
withBroker(newValidMockBroker()),
)
registry.ResetRegistration()
RegisterDDLCallbacks(core)
// TODO: we should merge all coordinator code into one package unit,
// so this mock code can be replaced with the real code.
registry.RegisterDropIndexV2AckCallback(func(ctx context.Context, result message.BroadcastResultDropIndexMessageV2) error {
return nil
})
registry.RegisterDropLoadConfigV2AckCallback(func(ctx context.Context, result message.BroadcastResultDropLoadConfigMessageV2) error {
return nil
})
wal := mock_streaming.NewMockWALAccesser(t)
wal.EXPECT().ControlChannel().Return(funcutil.GetControlChannel("by-dev-rootcoord-dml_0"))
wal.EXPECT().ControlChannel().Return(funcutil.GetControlChannel("by-dev-rootcoord-dml_0")).Maybe()
streaming.SetWALForTest(wal)
bapi := mock_broadcaster.NewMockBroadcastAPI(t)
@ -73,24 +115,44 @@ func initStreamingSystem() {
results := make(map[string]*message.AppendResult)
for _, vchannel := range msg.BroadcastHeader().VChannels {
results[vchannel] = &message.AppendResult{
MessageID: walimplstest.NewTestMessageID(1),
MessageID: rmq.NewRmqID(1),
TimeTick: tsoutil.ComposeTSByTime(time.Now(), 0),
LastConfirmedMessageID: walimplstest.NewTestMessageID(1),
LastConfirmedMessageID: rmq.NewRmqID(1),
}
}
registry.CallMessageAckCallback(context.Background(), msg, results)
retry.Do(context.Background(), func() error {
return registry.CallMessageAckCallback(context.Background(), msg, results)
}, retry.AttemptAlways())
return &types.BroadcastAppendResult{}, nil
})
bapi.EXPECT().Close().Return()
}).Maybe()
bapi.EXPECT().Close().Return().Maybe()
mb := mock_broadcaster.NewMockBroadcaster(t)
mb.EXPECT().WithResourceKeys(mock.Anything, mock.Anything).Return(bapi, nil).Maybe()
mb.EXPECT().WithResourceKeys(mock.Anything, mock.Anything, mock.Anything).Return(bapi, nil).Maybe()
mb.EXPECT().WithResourceKeys(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(bapi, nil).Maybe()
mb.EXPECT().Close().Return()
broadcast.Release()
mb.EXPECT().Close().Return().Maybe()
broadcast.ResetBroadcaster()
broadcast.Register(mb)
snmanager.ResetStreamingNodeManager()
b := mock_balancer.NewMockBalancer(t)
b.EXPECT().AllocVirtualChannels(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, param balancer.AllocVChannelParam) ([]string, error) {
vchannels := make([]string, 0, param.Num)
for i := 0; i < param.Num; i++ {
vchannels = append(vchannels, funcutil.GetVirtualChannel(fmt.Sprintf("%s-rootcoord-dml_%d_100v0", path, i), param.CollectionID, i))
}
return vchannels, nil
}).Maybe()
b.EXPECT().WatchChannelAssignments(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, callback balancer.WatchChannelAssignmentsCallback) error {
<-ctx.Done()
return ctx.Err()
}).Maybe()
b.EXPECT().Close().Return().Maybe()
balance.Register(b)
channel.ResetStaticPChannelStatsManager()
channel.RecoverPChannelStatsManager([]string{})
return core
}
func TestRootCoord_CreateDatabase(t *testing.T) {
@ -170,36 +232,6 @@ func TestRootCoord_CreateCollection(t *testing.T) {
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("failed to add task", func(t *testing.T) {
c := newTestCore(withHealthyCode(),
withInvalidScheduler())
ctx := context.Background()
resp, err := c.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("failed to execute", func(t *testing.T) {
c := newTestCore(withHealthyCode(),
withTaskFailScheduler())
ctx := context.Background()
resp, err := c.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("normal case, everything is ok", func(t *testing.T) {
c := newTestCore(withHealthyCode(),
withValidScheduler())
ctx := context.Background()
resp, err := c.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
}
func TestRootCoord_DropCollection(t *testing.T) {
@ -210,36 +242,6 @@ func TestRootCoord_DropCollection(t *testing.T) {
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("failed to add task", func(t *testing.T) {
c := newTestCore(withHealthyCode(),
withInvalidScheduler())
ctx := context.Background()
resp, err := c.DropCollection(ctx, &milvuspb.DropCollectionRequest{})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("failed to execute", func(t *testing.T) {
c := newTestCore(withHealthyCode(),
withTaskFailScheduler())
ctx := context.Background()
resp, err := c.DropCollection(ctx, &milvuspb.DropCollectionRequest{})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("normal case, everything is ok", func(t *testing.T) {
c := newTestCore(withHealthyCode(),
withValidScheduler())
ctx := context.Background()
resp, err := c.DropCollection(ctx, &milvuspb.DropCollectionRequest{})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
}
func TestRootCoord_CreatePartition(t *testing.T) {
@ -250,36 +252,6 @@ func TestRootCoord_CreatePartition(t *testing.T) {
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("failed to add task", func(t *testing.T) {
c := newTestCore(withHealthyCode(),
withInvalidScheduler())
ctx := context.Background()
resp, err := c.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("failed to execute", func(t *testing.T) {
c := newTestCore(withHealthyCode(),
withTaskFailScheduler())
ctx := context.Background()
resp, err := c.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("normal case, everything is ok", func(t *testing.T) {
c := newTestCore(withHealthyCode(),
withValidScheduler())
ctx := context.Background()
resp, err := c.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
}
func TestRootCoord_DropPartition(t *testing.T) {
@ -290,36 +262,6 @@ func TestRootCoord_DropPartition(t *testing.T) {
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("failed to add task", func(t *testing.T) {
c := newTestCore(withHealthyCode(),
withInvalidScheduler())
ctx := context.Background()
resp, err := c.DropPartition(ctx, &milvuspb.DropPartitionRequest{})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("failed to execute", func(t *testing.T) {
c := newTestCore(withHealthyCode(),
withTaskFailScheduler())
ctx := context.Background()
resp, err := c.DropPartition(ctx, &milvuspb.DropPartitionRequest{})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
t.Run("normal case, everything is ok", func(t *testing.T) {
c := newTestCore(withHealthyCode(),
withValidScheduler())
ctx := context.Background()
resp, err := c.DropPartition(ctx, &milvuspb.DropPartitionRequest{})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode())
})
}
func TestRootCoord_CreateAlias(t *testing.T) {
@ -1734,25 +1676,6 @@ type RootCoordSuite struct {
func (s *RootCoordSuite) TestRestore() {
meta := mockrootcoord.NewIMetaTable(s.T())
gc := mockrootcoord.NewGarbageCollector(s.T())
finishCh := make(chan struct{}, 4)
gc.EXPECT().ReDropPartition(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Once().
Run(func(args mock.Arguments) {
finishCh <- struct{}{}
})
gc.EXPECT().RemoveCreatingPartition(mock.Anything, mock.Anything, mock.Anything).Once().
Run(func(args mock.Arguments) {
finishCh <- struct{}{}
})
gc.EXPECT().ReDropCollection(mock.Anything, mock.Anything).Once().
Run(func(args mock.Arguments) {
finishCh <- struct{}{}
})
gc.EXPECT().RemoveCreatingCollection(mock.Anything).Once().
Run(func(args mock.Arguments) {
finishCh <- struct{}{}
})
meta.EXPECT().ListDatabases(mock.Anything, mock.Anything).
Return([]*model.Database{
@ -1813,16 +1736,9 @@ func (s *RootCoordSuite) TestRestore() {
return 100, nil
}
core := newTestCore(
withGarbageCollector(gc),
// withTtSynchronizer(ticker),
withTsoAllocator(tsoAllocator),
// withValidProxyManager(),
withMeta(meta))
core.restore(context.Background())
for i := 0; i < 4; i++ {
<-finishCh
}
}
func TestRootCoordSuite(t *testing.T) {

View File

@ -24,18 +24,15 @@ import (
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/util/proxyutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
pb "github.com/milvus-io/milvus/pkg/v2/proto/etcdpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
)
type stepPriority int
@ -65,132 +62,6 @@ func (s baseStep) Weight() stepPriority {
return stepPriorityLow
}
type addCollectionMetaStep struct {
baseStep
coll *model.Collection
}
func (s *addCollectionMetaStep) Execute(ctx context.Context) ([]nestedStep, error) {
err := s.core.meta.AddCollection(ctx, s.coll)
return nil, err
}
func (s *addCollectionMetaStep) Desc() string {
return fmt.Sprintf("add collection to meta table, name: %s, id: %d, ts: %d", s.coll.Name, s.coll.CollectionID, s.coll.CreateTime)
}
type deleteCollectionMetaStep struct {
baseStep
collectionID UniqueID
ts Timestamp
}
func (s *deleteCollectionMetaStep) Execute(ctx context.Context) ([]nestedStep, error) {
err := s.core.meta.RemoveCollection(ctx, s.collectionID, s.ts)
return nil, err
}
func (s *deleteCollectionMetaStep) Desc() string {
return fmt.Sprintf("delete collection from meta table, id: %d, ts: %d", s.collectionID, s.ts)
}
func (s *deleteCollectionMetaStep) Weight() stepPriority {
return stepPriorityNormal
}
type deleteDatabaseMetaStep struct {
baseStep
databaseName string
ts Timestamp
}
func (s *deleteDatabaseMetaStep) Execute(ctx context.Context) ([]nestedStep, error) {
err := s.core.meta.DropDatabase(ctx, s.databaseName, s.ts)
return nil, err
}
func (s *deleteDatabaseMetaStep) Desc() string {
return fmt.Sprintf("delete database from meta table, name: %s, ts: %d", s.databaseName, s.ts)
}
type removeDmlChannelsStep struct {
baseStep
pChannels []string
}
func (s *removeDmlChannelsStep) Execute(ctx context.Context) ([]nestedStep, error) {
s.core.chanTimeTick.removeDmlChannels(s.pChannels...)
return nil, nil
}
func (s *removeDmlChannelsStep) Desc() string {
// this shouldn't be called.
return fmt.Sprintf("remove dml channels: %v", s.pChannels)
}
func (s *removeDmlChannelsStep) Weight() stepPriority {
// avoid too frequent tt.
return stepPriorityUrgent
}
type watchChannelsStep struct {
baseStep
info *watchInfo
}
func (s *watchChannelsStep) Execute(ctx context.Context) ([]nestedStep, error) {
err := s.core.broker.WatchChannels(ctx, s.info)
return nil, err
}
func (s *watchChannelsStep) Desc() string {
return fmt.Sprintf("watch channels, ts: %d, collection: %d, partition: %d, vChannels: %v",
s.info.ts, s.info.collectionID, s.info.partitionID, s.info.vChannels)
}
type unwatchChannelsStep struct {
baseStep
collectionID UniqueID
channels collectionChannels
isSkip bool
}
func (s *unwatchChannelsStep) Execute(ctx context.Context) ([]nestedStep, error) {
unwatchByDropMsg := &deleteCollectionDataStep{
baseStep: baseStep{core: s.core},
coll: &model.Collection{CollectionID: s.collectionID, PhysicalChannelNames: s.channels.physicalChannels},
isSkip: s.isSkip,
}
return unwatchByDropMsg.Execute(ctx)
}
func (s *unwatchChannelsStep) Desc() string {
return fmt.Sprintf("unwatch channels, collection: %d, pChannels: %v, vChannels: %v",
s.collectionID, s.channels.physicalChannels, s.channels.virtualChannels)
}
func (s *unwatchChannelsStep) Weight() stepPriority {
return stepPriorityNormal
}
type changeCollectionStateStep struct {
baseStep
collectionID UniqueID
state pb.CollectionState
ts Timestamp
}
func (s *changeCollectionStateStep) Execute(ctx context.Context) ([]nestedStep, error) {
err := s.core.meta.ChangeCollectionState(ctx, s.collectionID, s.state, s.ts)
return nil, err
}
func (s *changeCollectionStateStep) Desc() string {
return fmt.Sprintf("change collection state, collection: %d, ts: %d, state: %s",
s.collectionID, s.ts, s.state.String())
}
type cleanupMetricsStep struct {
baseStep
dbName string
@ -227,56 +98,6 @@ func (s *expireCacheStep) Desc() string {
s.collectionID, s.collectionNames, s.ts)
}
type deleteCollectionDataStep struct {
baseStep
coll *model.Collection
isSkip bool
}
func (s *deleteCollectionDataStep) Execute(ctx context.Context) ([]nestedStep, error) {
if s.isSkip {
return nil, nil
}
if _, err := s.core.garbageCollector.GcCollectionData(ctx, s.coll); err != nil {
return nil, err
}
return nil, nil
}
func (s *deleteCollectionDataStep) Desc() string {
return fmt.Sprintf("delete collection data, collection: %d", s.coll.CollectionID)
}
func (s *deleteCollectionDataStep) Weight() stepPriority {
return stepPriorityImportant
}
type deletePartitionDataStep struct {
baseStep
pchans []string
vchans []string
partition *model.Partition
isSkip bool
}
func (s *deletePartitionDataStep) Execute(ctx context.Context) ([]nestedStep, error) {
if s.isSkip {
return nil, nil
}
_, err := s.core.garbageCollector.GcPartitionData(ctx, s.pchans, s.vchans, s.partition)
return nil, err
}
func (s *deletePartitionDataStep) Desc() string {
return fmt.Sprintf("delete partition data, collection: %d, partition: %d", s.partition.CollectionID, s.partition.PartitionID)
}
func (s *deletePartitionDataStep) Weight() stepPriority {
return stepPriorityImportant
}
type releaseCollectionStep struct {
baseStep
collectionID UniqueID
@ -334,104 +155,6 @@ func (s *dropIndexStep) Weight() stepPriority {
return stepPriorityNormal
}
type addPartitionMetaStep struct {
baseStep
partition *model.Partition
}
func (s *addPartitionMetaStep) Execute(ctx context.Context) ([]nestedStep, error) {
err := s.core.meta.AddPartition(ctx, s.partition)
return nil, err
}
func (s *addPartitionMetaStep) Desc() string {
return fmt.Sprintf("add partition to meta table, collection: %d, partition: %d", s.partition.CollectionID, s.partition.PartitionID)
}
type broadcastCreatePartitionMsgStep struct {
baseStep
vchannels []string
partition *model.Partition
ts Timestamp
}
func (s *broadcastCreatePartitionMsgStep) Execute(ctx context.Context) ([]nestedStep, error) {
req := &msgpb.CreatePartitionRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_CreatePartition),
commonpbutil.WithTimeStamp(0), // ts is given by streamingnode.
),
PartitionName: s.partition.PartitionName,
CollectionID: s.partition.CollectionID,
PartitionID: s.partition.PartitionID,
}
msgs := make([]message.MutableMessage, 0, len(s.vchannels))
for _, vchannel := range s.vchannels {
msg, err := message.NewCreatePartitionMessageBuilderV1().
WithVChannel(vchannel).
WithHeader(&message.CreatePartitionMessageHeader{
CollectionId: s.partition.CollectionID,
PartitionId: s.partition.PartitionID,
}).
WithBody(req).
BuildMutable()
if err != nil {
return nil, err
}
msgs = append(msgs, msg)
}
if err := streaming.WAL().AppendMessagesWithOption(ctx, streaming.AppendOption{
BarrierTimeTick: s.ts,
}, msgs...).UnwrapFirstError(); err != nil {
return nil, err
}
return nil, nil
}
func (s *broadcastCreatePartitionMsgStep) Desc() string {
return fmt.Sprintf("broadcast create partition message to mq, collection: %d, partition: %d", s.partition.CollectionID, s.partition.PartitionID)
}
type changePartitionStateStep struct {
baseStep
collectionID UniqueID
partitionID UniqueID
state pb.PartitionState
ts Timestamp
}
func (s *changePartitionStateStep) Execute(ctx context.Context) ([]nestedStep, error) {
err := s.core.meta.ChangePartitionState(ctx, s.collectionID, s.partitionID, s.state, s.ts)
return nil, err
}
func (s *changePartitionStateStep) Desc() string {
return fmt.Sprintf("change partition step, collection: %d, partition: %d, state: %s, ts: %d",
s.collectionID, s.partitionID, s.state.String(), s.ts)
}
type removePartitionMetaStep struct {
baseStep
dbID UniqueID
collectionID UniqueID
partitionID UniqueID
ts Timestamp
}
func (s *removePartitionMetaStep) Execute(ctx context.Context) ([]nestedStep, error) {
err := s.core.meta.RemovePartition(ctx, s.dbID, s.collectionID, s.partitionID, s.ts)
return nil, err
}
func (s *removePartitionMetaStep) Desc() string {
return fmt.Sprintf("remove partition meta, collection: %d, partition: %d, ts: %d", s.collectionID, s.partitionID, s.ts)
}
func (s *removePartitionMetaStep) Weight() stepPriority {
return stepPriorityNormal
}
type nullStep struct{}
func (s *nullStep) Execute(ctx context.Context) ([]nestedStep, error) {
@ -574,7 +297,7 @@ func (s *renameCollectionStep) Desc() string {
var (
confirmGCInterval = time.Minute * 20
allPartition UniqueID = -1
allPartition UniqueID = common.AllPartitionsID
)
type confirmGCStep struct {

View File

@ -22,11 +22,6 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/mocks/distributed/mock_streaming"
)
func restoreConfirmGCInterval() {
@ -79,41 +74,3 @@ func Test_confirmGCStep_Execute(t *testing.T) {
assert.NoError(t, err)
})
}
func TestSkip(t *testing.T) {
{
s := &unwatchChannelsStep{isSkip: true}
_, err := s.Execute(context.Background())
assert.NoError(t, err)
}
{
s := &deleteCollectionDataStep{isSkip: true}
_, err := s.Execute(context.Background())
assert.NoError(t, err)
}
{
s := &deletePartitionDataStep{isSkip: true}
_, err := s.Execute(context.Background())
assert.NoError(t, err)
}
}
func TestBroadcastCreatePartitionMsgStep(t *testing.T) {
wal := mock_streaming.NewMockWALAccesser(t)
wal.EXPECT().AppendMessagesWithOption(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(streaming.AppendResponses{})
streaming.SetWALForTest(wal)
step := &broadcastCreatePartitionMsgStep{
baseStep: baseStep{core: nil},
vchannels: []string{"ch-0", "ch-1"},
partition: &model.Partition{
CollectionID: 1,
PartitionID: 2,
},
}
t.Logf("%v\n", step.Desc())
_, err := step.Execute(context.Background())
assert.NoError(t, err)
}

View File

@ -114,40 +114,6 @@ func TestGetLockerKey(t *testing.T) {
key := tt.GetLockerKey()
assert.Equal(t, GetLockerKeyString(key), "$-0-false|foo-1-false|111-2-true")
})
t.Run("create collection task locker key", func(t *testing.T) {
tt := &createCollectionTask{
Req: &milvuspb.CreateCollectionRequest{
DbName: "foo",
CollectionName: "bar",
},
collID: 10,
}
key := tt.GetLockerKey()
assert.Equal(t, GetLockerKeyString(key), "$-0-false|foo-1-false|10-2-true")
})
t.Run("create partition task locker key", func(t *testing.T) {
metaMock := mockrootcoord.NewIMetaTable(t)
metaMock.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything).
RunAndReturn(func(ctx context.Context, s string, s2 string, u uint64) (*model.Collection, error) {
return &model.Collection{
Name: "real" + s2,
CollectionID: 111,
}, nil
})
c := &Core{
meta: metaMock,
}
tt := &createPartitionTask{
baseTask: baseTask{core: c},
Req: &milvuspb.CreatePartitionRequest{
DbName: "foo",
CollectionName: "bar",
PartitionName: "baz",
},
}
key := tt.GetLockerKey()
assert.Equal(t, GetLockerKeyString(key), "$-0-false|foo-1-false|111-2-true")
})
t.Run("describe collection task locker key", func(t *testing.T) {
metaMock := mockrootcoord.NewIMetaTable(t)
metaMock.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything).
@ -192,51 +158,6 @@ func TestGetLockerKey(t *testing.T) {
key := tt.GetLockerKey()
assert.Equal(t, GetLockerKeyString(key), "$-0-false|foo-1-false")
})
t.Run("drop collection task locker key", func(t *testing.T) {
metaMock := mockrootcoord.NewIMetaTable(t)
metaMock.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything).
RunAndReturn(func(ctx context.Context, s string, s2 string, u uint64) (*model.Collection, error) {
return &model.Collection{
Name: "bar",
CollectionID: 111,
}, nil
})
c := &Core{
meta: metaMock,
}
tt := &dropCollectionTask{
baseTask: baseTask{core: c},
Req: &milvuspb.DropCollectionRequest{
DbName: "foo",
CollectionName: "bar",
},
}
key := tt.GetLockerKey()
assert.Equal(t, GetLockerKeyString(key), "$-0-false|foo-1-false|111-2-true")
})
t.Run("drop partition task locker key", func(t *testing.T) {
metaMock := mockrootcoord.NewIMetaTable(t)
metaMock.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything).
RunAndReturn(func(ctx context.Context, s string, s2 string, u uint64) (*model.Collection, error) {
return &model.Collection{
Name: "real" + s2,
CollectionID: 111,
}, nil
})
c := &Core{
meta: metaMock,
}
tt := &dropPartitionTask{
baseTask: baseTask{core: c},
Req: &milvuspb.DropPartitionRequest{
DbName: "foo",
CollectionName: "bar",
PartitionName: "baz",
},
}
key := tt.GetLockerKey()
assert.Equal(t, GetLockerKeyString(key), "$-0-false|foo-1-false|111-2-true")
})
t.Run("has collection task locker key", func(t *testing.T) {
tt := &hasCollectionTask{
Req: &milvuspb.HasCollectionRequest{

View File

@ -0,0 +1,121 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tombstone
import (
"context"
"time"
"go.uber.org/zap"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// NewTombstoneSweeper creates a new tombstone sweeper.
// It will start a background goroutine to sweep the tombstones periodically.
// Once the tombstone is safe to be removed, it will be removed by the background goroutine.
func NewTombstoneSweeper() TombstoneSweeper {
ts := &tombstoneSweeperImpl{
notifier: syncutil.NewAsyncTaskNotifier[struct{}](),
incoming: make(chan Tombstone),
tombstones: make(map[string]Tombstone),
interval: 5 * time.Minute,
}
ts.SetLogger(log.With(log.FieldModule(typeutil.RootCoordRole), log.FieldComponent("tombstone_sweeper")))
go ts.background()
return ts
}
// tombstoneSweeperImpl is the default implementation of TombstoneSweeper.
type tombstoneSweeperImpl struct {
log.Binder
notifier *syncutil.AsyncTaskNotifier[struct{}]
incoming chan Tombstone
tombstones map[string]Tombstone
interval time.Duration
// TODO: add metrics for the tombstone sweeper.
}
// AddTombstone adds a tombstone to the sweeper.
func (s *tombstoneSweeperImpl) AddTombstone(tombstone Tombstone) {
select {
case <-s.notifier.Context().Done():
case s.incoming <- tombstone:
}
}
func (s *tombstoneSweeperImpl) background() {
defer func() {
s.notifier.Finish(struct{}{})
s.Logger().Info("tombstone sweeper background exit")
}()
s.Logger().Info("tombstone sweeper background start", zap.Duration("interval", s.interval))
ticker := time.NewTicker(s.interval)
defer ticker.Stop()
for {
select {
case tombstone := <-s.incoming:
if _, ok := s.tombstones[tombstone.ID()]; !ok {
s.tombstones[tombstone.ID()] = tombstone
s.Logger().Info("tombstone added", zap.String("tombstone", tombstone.ID()))
}
case <-ticker.C:
s.triggerGCTombstone(s.notifier.Context())
case <-s.notifier.Context().Done():
return
}
}
}
// triggerGCTombstone triggers the garbage collection of the tombstones.
func (s *tombstoneSweeperImpl) triggerGCTombstone(ctx context.Context) {
if len(s.tombstones) == 0 {
return
}
for _, tombstone := range s.tombstones {
if ctx.Err() != nil {
// The tombstone sweeper is closing, stop it.
return
}
tombstoneID := tombstone.ID()
confirmed, err := tombstone.ConfirmCanBeRemoved(ctx)
if err != nil {
s.Logger().Warn("fail to confirm if tombstone can be removed", zap.String("tombstone", tombstoneID), zap.Error(err))
continue
}
if !confirmed {
continue
}
if err := tombstone.Remove(ctx); err != nil {
s.Logger().Warn("fail to remove tombstone", zap.String("tombstone", tombstoneID), zap.Error(err))
continue
}
delete(s.tombstones, tombstoneID)
s.Logger().Info("tombstone removed", zap.String("tombstone", tombstoneID))
}
}
// Close closes the tombstone sweeper.
func (s *tombstoneSweeperImpl) Close() {
s.notifier.Cancel()
s.notifier.BlockUntilFinish()
}

View File

@ -0,0 +1,93 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tombstone
import (
"context"
"math/rand"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"go.uber.org/atomic"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
func TestTombstoneSweeper_AddTombstone(t *testing.T) {
sweeper := NewTombstoneSweeper()
sweeper.Close()
sweeperImpl := &tombstoneSweeperImpl{
notifier: syncutil.NewAsyncTaskNotifier[struct{}](),
incoming: make(chan Tombstone),
tombstones: make(map[string]Tombstone),
interval: 1 * time.Millisecond,
}
go sweeperImpl.background()
testTombstone := &testTombstoneImpl{
id: "test",
confirmed: atomic.NewBool(false),
canRemove: atomic.NewBool(false),
removed: atomic.NewBool(false),
}
sweeperImpl.AddTombstone(testTombstone)
time.Sleep(5 * time.Millisecond)
assert.False(t, testTombstone.removed.Load())
testTombstone.confirmed.Store(true)
time.Sleep(5 * time.Millisecond)
assert.False(t, testTombstone.removed.Load())
testTombstone.canRemove.Store(true)
assert.Eventually(t, func() bool {
return testTombstone.removed.Load()
}, 100*time.Millisecond, 10*time.Millisecond)
sweeperImpl.Close()
assert.Len(t, sweeperImpl.tombstones, 0)
}
type testTombstoneImpl struct {
id string
confirmed *atomic.Bool
canRemove *atomic.Bool
removed *atomic.Bool
}
func (t *testTombstoneImpl) ID() string {
return t.id
}
func (t *testTombstoneImpl) ConfirmCanBeRemoved(ctx context.Context) (bool, error) {
if rand.Intn(2) == 0 {
return false, errors.New("fail to confirm")
}
return t.confirmed.Load(), nil
}
func (t *testTombstoneImpl) Remove(ctx context.Context) error {
if !t.canRemove.Load() {
return errors.New("tombstone can not be removed")
}
t.removed.Store(true)
return nil
}

View File

@ -0,0 +1,36 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tombstone
import "context"
// TombstoneSweeper sweeps the registered tombstones in the background and removes them once they can be safely removed.
type TombstoneSweeper interface {
// AddTombstone registers a tombstone to be swept.
AddTombstone(tombstone Tombstone)
// Close stops the background sweeper.
Close()
}
// Tombstone is the interface for the tombstone.
type Tombstone interface {
// ID returns the unique identifier of the tombstone.
ID() string
// ConfirmCanBeRemoved checks if the tombstone can be removed forever.
ConfirmCanBeRemoved(ctx context.Context) (bool, error)
// Remove removes the tombstone.
Remove(ctx context.Context) error
}
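
To make the contract above concrete, here is a minimal, hypothetical sketch of a Tombstone implementation wired into the sweeper. The exampleTombstone type and its callback fields are invented for illustration only; they are not part of this change, which registers collection/partition tombstones instead.

package tombstone_test

import (
	"context"

	"github.com/milvus-io/milvus/internal/rootcoord/tombstone"
)

// exampleTombstone is a hypothetical Tombstone implementation used only for
// illustration; the real tombstones wrap collection/partition meta.
type exampleTombstone struct {
	id        string
	canRemove func(ctx context.Context) (bool, error)
	remove    func(ctx context.Context) error
}

func (t *exampleTombstone) ID() string { return t.id }

func (t *exampleTombstone) ConfirmCanBeRemoved(ctx context.Context) (bool, error) {
	return t.canRemove(ctx)
}

func (t *exampleTombstone) Remove(ctx context.Context) error { return t.remove(ctx) }

func Example_sweeperUsage() {
	sweeper := tombstone.NewTombstoneSweeper()
	defer sweeper.Close()

	// The sweeper keeps the tombstone and retries on every tick until
	// ConfirmCanBeRemoved reports true and Remove succeeds.
	sweeper.AddTombstone(&exampleTombstone{
		id:        "collection-100",
		canRemove: func(ctx context.Context) (bool, error) { return true, nil },
		remove:    func(ctx context.Context) error { return nil },
	})
}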

View File

@ -19,6 +19,7 @@ var (
)
type (
AllocVChannelParam = channel.AllocVChannelParam
WatchChannelAssignmentsCallbackParam = channel.WatchChannelAssignmentsCallbackParam
WatchChannelAssignmentsCallback = channel.WatchChannelAssignmentsCallback
)
@ -34,6 +35,9 @@ type Balancer interface {
// GetAllStreamingNodes fetches all streaming node info.
GetAllStreamingNodes(ctx context.Context) (map[int64]*types.StreamingNodeInfo, error)
// AllocVirtualChannels allocates virtual channels for a collection.
AllocVirtualChannels(ctx context.Context, param AllocVChannelParam) ([]string, error)
// UpdateBalancePolicy updates the balance policy.
UpdateBalancePolicy(ctx context.Context, req *streamingpb.UpdateWALBalancePolicyRequest) (*streamingpb.UpdateWALBalancePolicyResponse, error)

View File

@ -138,6 +138,11 @@ func (b *balancerImpl) UpdateReplicateConfiguration(ctx context.Context, result
return nil
}
// AllocVirtualChannels allocates virtual channels for a collection.
func (b *balancerImpl) AllocVirtualChannels(ctx context.Context, param AllocVChannelParam) ([]string, error) {
return b.channelMetaManager.AllocVirtualChannels(ctx, param)
}
// UpdateBalancePolicy updates the balance policy.
func (b *balancerImpl) UpdateBalancePolicy(ctx context.Context, req *types.UpdateWALBalancePolicyRequest) (*types.UpdateWALBalancePolicyResponse, error) {
if !b.lifetime.Add(typeutil.LifetimeStateWorking) {

View File

@ -2,6 +2,7 @@ package channel
import (
"context"
"sort"
"sync"
"github.com/cockroachdb/errors"
@ -25,6 +26,11 @@ import (
var ErrChannelNotExist = errors.New("channel not exist")
type (
AllocVChannelParam struct {
CollectionID int64
Num int
}
WatchChannelAssignmentsCallbackParam struct {
Version typeutil.VersionInt64Pair
CChannelAssignment *streamingpb.CChannelAssignment
@ -238,6 +244,49 @@ func (cm *ChannelManager) CurrentPChannelsView() *PChannelView {
return view
}
// AllocVirtualChannels allocates virtual channels for a collection.
func (cm *ChannelManager) AllocVirtualChannels(ctx context.Context, param AllocVChannelParam) ([]string, error) {
cm.cond.L.Lock()
defer cm.cond.L.Unlock()
if len(cm.channels) < param.Num {
return nil, errors.Errorf("not enough pchannels to allocate, expected: %d, got: %d", param.Num, len(cm.channels))
}
vchannels := make([]string, 0, param.Num)
for _, channel := range cm.sortChannelsByVChannelCount() {
if len(vchannels) >= param.Num {
break
}
vchannels = append(vchannels, funcutil.GetVirtualChannel(channel.id.Name, param.CollectionID, len(vchannels)))
}
return vchannels, nil
}
// withVChannelCount is a helper struct to sort the channels by the vchannel count.
type withVChannelCount struct {
id ChannelID
vchannelCount int
}
// sortChannelsByVChannelCount sorts the channels by the vchannel count.
func (cm *ChannelManager) sortChannelsByVChannelCount() []withVChannelCount {
vchannelCounts := make([]withVChannelCount, 0, len(cm.channels))
for id := range cm.channels {
vchannelCounts = append(vchannelCounts, withVChannelCount{
id: id,
vchannelCount: StaticPChannelStatsManager.Get().GetPChannelStats(id).VChannelCount(),
})
}
sort.Slice(vchannelCounts, func(i, j int) bool {
if vchannelCounts[i].vchannelCount == vchannelCounts[j].vchannelCount {
// keep the sort result stable: channels with the same vchannel count are ordered by name.
return vchannelCounts[i].id.Name < vchannelCounts[j].id.Name
}
return vchannelCounts[i].vchannelCount < vchannelCounts[j].vchannelCount
})
return vchannelCounts
}
// AssignPChannels updates the pchannel-to-server assignment and returns the modified pchannels.
// When the balancer wants to assign a pchannel to a new server,
// it should always call this function first to update the pchannel assignment.
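
For illustration only, here is a self-contained sketch of the selection rule that AllocVirtualChannels relies on: take the pchannels hosting the fewest vchannels, breaking ties by name. The channelLoad type and channel names below are hypothetical stand-ins, not the milvus types.

package main

import (
	"fmt"
	"sort"
)

// channelLoad mirrors the withVChannelCount helper: a pchannel name plus how
// many vchannels it currently hosts.
type channelLoad struct {
	name          string
	vchannelCount int
}

// pickLeastLoaded returns the num channels with the fewest vchannels,
// tie-broken by name, which is the ordering AllocVirtualChannels relies on.
func pickLeastLoaded(channels []channelLoad, num int) []string {
	sort.Slice(channels, func(i, j int) bool {
		if channels[i].vchannelCount == channels[j].vchannelCount {
			return channels[i].name < channels[j].name
		}
		return channels[i].vchannelCount < channels[j].vchannelCount
	})
	picked := make([]string, 0, num)
	for _, ch := range channels {
		if len(picked) >= num {
			break
		}
		picked = append(picked, ch.name)
	}
	return picked
}

func main() {
	channels := []channelLoad{
		{name: "dml_0", vchannelCount: 2},
		{name: "dml_1", vchannelCount: 1},
		{name: "dml_2", vchannelCount: 1},
	}
	// Prints [dml_1 dml_2]: the two least-loaded pchannels, ordered by name on the tie.
	fmt.Println(pickLeastLoaded(channels, 2))
}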

View File

@ -13,6 +13,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/mocks/mock_metastore"
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
"github.com/milvus-io/milvus/internal/util/streamingutil/util"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
@ -356,6 +357,48 @@ func TestChannelManager(t *testing.T) {
})
}
func TestAllocVirtualChannels(t *testing.T) {
ResetStaticPChannelStatsManager()
RecoverPChannelStatsManager([]string{})
catalog := mock_metastore.NewMockStreamingCoordCataLog(t)
resource.InitForTest(resource.OptStreamingCatalog(catalog))
// Set up the catalog mocks so the channel manager can recover.
catalog.EXPECT().GetCChannel(mock.Anything).Return(&streamingpb.CChannelMeta{
Pchannel: "test-channel",
}, nil).Maybe()
catalog.EXPECT().GetVersion(mock.Anything).Return(nil, nil).Maybe()
catalog.EXPECT().SaveVersion(mock.Anything, mock.Anything).Return(nil).Maybe()
catalog.EXPECT().ListPChannel(mock.Anything).Return(nil, nil).Maybe()
catalog.EXPECT().GetReplicateConfiguration(mock.Anything).Return(nil, nil).Maybe()
ctx := context.Background()
newIncomingTopics := util.GetAllTopicsFromConfiguration()
m, err := RecoverChannelManager(ctx, newIncomingTopics.Collect()...)
assert.NoError(t, err)
assert.NotNil(t, m)
allocVChannels, err := m.AllocVirtualChannels(ctx, AllocVChannelParam{
CollectionID: 1,
Num: 256,
})
assert.Error(t, err)
assert.Nil(t, allocVChannels)
StaticPChannelStatsManager.Get().AddVChannel("by-dev-rootcoord-dml_0_100v0", "by-dev-rootcoord-dml_0_101v0", "by-dev-rootcoord-dml_1_100v1")
allocVChannels, err = m.AllocVirtualChannels(ctx, AllocVChannelParam{
CollectionID: 1,
Num: 4,
})
assert.NoError(t, err)
assert.Len(t, allocVChannels, 4)
assert.Equal(t, allocVChannels[0], "by-dev-rootcoord-dml_10_1v0")
assert.Equal(t, allocVChannels[1], "by-dev-rootcoord-dml_11_1v1")
assert.Equal(t, allocVChannels[2], "by-dev-rootcoord-dml_12_1v2")
assert.Equal(t, allocVChannels[3], "by-dev-rootcoord-dml_13_1v3")
}
func TestStreamingEnableChecker(t *testing.T) {
ctx := context.Background()
ResetStaticPChannelStatsManager()


@ -45,6 +45,18 @@ type assignmentSnapshot struct {
GlobalUnbalancedScore float64
}
// Clone will clone the assignment snapshot.
func (s *assignmentSnapshot) Clone() assignmentSnapshot {
assignments := make(map[types.ChannelID]types.PChannelInfoAssigned, len(s.Assignments))
for channelID, assignment := range s.Assignments {
assignments[channelID] = assignment
}
return assignmentSnapshot{
Assignments: assignments,
GlobalUnbalancedScore: s.GlobalUnbalancedScore,
}
}
// streamingNodeInfo is the streaming node info for vchannel fair policy.
type streamingNodeInfo struct {
AssignedVChannelCount int


@ -256,3 +256,20 @@ func newLayout(channels map[string]int, vchannels map[string]map[string]int64, s
}
return layout
}
func TestAssignmentClone(t *testing.T) {
snapshot := assignmentSnapshot{
Assignments: map[types.ChannelID]types.PChannelInfoAssigned{
newChannelID("c1"): {
Channel: types.PChannelInfo{
Name: "c1",
},
},
},
}
clonedSnapshot := snapshot.Clone()
clonedSnapshot.Assignments[newChannelID("c2")] = types.PChannelInfoAssigned{}
assert.Len(t, snapshot.Assignments, 1)
assert.Equal(t, snapshot.Assignments[newChannelID("c1")], clonedSnapshot.Assignments[newChannelID("c1")])
assert.Len(t, clonedSnapshot.Assignments, 2)
}
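The deep copy matters because assignmentSnapshot embeds a map: copying the struct by value still aliases Assignments, so mutations made while exploring candidate assignments would leak back into the baseline snapshot. That is what TestAssignmentClone guards against, and what the one-line greatestSnapshot := snapshot.Clone() change in the next file relies on. A generic sketch of the pitfall, using illustrative types rather than the real balancer types:

package main

import "fmt"

type snapshot struct {
	Assignments map[string]int
	Score       float64
}

// Clone copies the map so later mutations do not alias the original snapshot.
func (s snapshot) Clone() snapshot {
	assignments := make(map[string]int, len(s.Assignments))
	for k, v := range s.Assignments {
		assignments[k] = v
	}
	return snapshot{Assignments: assignments, Score: s.Score}
}

func main() {
	base := snapshot{Assignments: map[string]int{"c1": 1}}

	aliased := base // struct copy, but the map header is shared
	aliased.Assignments["c2"] = 2
	fmt.Println(len(base.Assignments)) // 2: the "copy" mutated the original

	base = snapshot{Assignments: map[string]int{"c1": 1}}
	cloned := base.Clone()
	cloned.Assignments["c2"] = 2
	fmt.Println(len(base.Assignments)) // 1: the clone is independent
}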


@ -94,7 +94,7 @@ func (p *policy) Balance(currentLayout balancer.CurrentLayout) (layout balancer.
// 4. Do a DFS to make a greatest snapshot.
// The DFS will find the assignment that minimizes the unbalance score based on the current layout.
-greatestSnapshot := snapshot
+greatestSnapshot := snapshot.Clone()
p.assignChannels(expectedLayout, reassignChannelIDs, &greatestSnapshot)
if greatestSnapshot.GlobalUnbalancedScore < snapshot.GlobalUnbalancedScore-p.cfg.RebalanceTolerance {
if p.Logger().Level().Enabled(zap.DebugLevel) {


@ -9,5 +9,6 @@ import (
)
func ResetBroadcaster() {
Release()
singleton = syncutil.NewFuture[broadcaster.Broadcaster]()
}


@ -416,7 +416,7 @@ func (b *broadcastTask) saveTaskIfDirty(ctx context.Context, logger *log.MLogger
logger = logger.With(zap.String("state", b.task.State.String()), zap.Int("ackedVChannelCount", ackedCount(b.task)))
if err := resource.Resource().StreamingCatalog().SaveBroadcastTask(ctx, b.Header().BroadcastID, b.task); err != nil {
logger.Warn("save broadcast task failed", zap.Error(err))
-if ctx.Err() != nil {
+if ctx.Err() == nil {
panic("critical error: the save broadcast task is failed before the context is done")
}
return err
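The inverted check above encodes an invariant: a metastore save may legitimately fail once the caller's context is done (cancellation or timeout), but a failure while the context is still alive is treated as unrecoverable. A hedged sketch of that classification pattern, with a stand-in save function instead of the real catalog call:

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// saveWithInvariant returns the error when the failure can be explained by the caller's
// context being done; otherwise it panics, mirroring the intent of the corrected check.
func saveWithInvariant(ctx context.Context, save func(context.Context) error) error {
	if err := save(ctx); err != nil {
		if ctx.Err() == nil {
			panic(fmt.Sprintf("critical error: save failed before the context is done: %v", err))
		}
		return err
	}
	return nil
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
	defer cancel()
	<-ctx.Done()
	// The context is already done, so the failure is returned instead of panicking.
	err := saveWithInvariant(ctx, func(context.Context) error { return errors.New("metastore unavailable") })
	fmt.Println(err)
}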


@ -42,6 +42,7 @@ func TestAssignmentService(t *testing.T) {
<-ctx.Done()
return ctx.Err()
})
b.EXPECT().Close().Return().Maybe()
balance.Register(b)
// Set up the broadcaster
@ -54,6 +55,7 @@ func TestAssignmentService(t *testing.T) {
mb.EXPECT().WithResourceKeys(mock.Anything, mock.Anything).Return(mba, nil).Maybe()
mb.EXPECT().Ack(mock.Anything, mock.Anything).Return(nil).Maybe()
mb.EXPECT().LegacyAck(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe()
mb.EXPECT().Close().Return().Maybe()
broadcast.Register(mb)
// Test assignment discover


@ -28,6 +28,8 @@ func (s *broadcastServceImpl) Broadcast(ctx context.Context, req *streamingpb.Br
if err != nil {
return nil, err
}
defer api.Close()
results, err := api.Broadcast(ctx, msg)
if err != nil {
return nil, err
@ -52,6 +54,8 @@ func (s *broadcastServceImpl) Ack(ctx context.Context, req *streamingpb.Broadcas
if err != nil {
return nil, err
}
// Once the ack request reaches the streamingcoord, the ack operation should not be cancelable by the client.
ctx = context.WithoutCancel(ctx)
if req.Message == nil {
// before 2.6.1, the request doesn't have the message field, only the broadcast id and vchannel.
// so we need to use the legacy ack interface.
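context.WithoutCancel (standard library, Go 1.21+) returns a context that keeps the parent's values but ignores its cancellation and deadline, which is what lets the ack finish even if the client RPC is canceled midway. A small demonstration of that behavior:

package main

import (
	"context"
	"fmt"
)

type ctxKey struct{}

func main() {
	parent := context.WithValue(context.Background(), ctxKey{}, "request-id-42")
	parent, cancel := context.WithCancel(parent)

	detached := context.WithoutCancel(parent)
	cancel() // the client gives up / the RPC is canceled

	fmt.Println(parent.Err())             // context canceled
	fmt.Println(detached.Err())           // <nil>: the detached context is unaffected
	fmt.Println(detached.Value(ctxKey{})) // request-id-42: values are still propagated
}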


@ -3,7 +3,9 @@ package service
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -23,12 +25,14 @@ func TestBroadcastService(t *testing.T) {
fb := syncutil.NewFuture[broadcaster.Broadcaster]()
mba := mock_broadcaster.NewMockBroadcastAPI(t)
mba.EXPECT().Close().Return()
mb := mock_broadcaster.NewMockBroadcaster(t)
fb.Set(mb)
mba.EXPECT().Broadcast(mock.Anything, mock.Anything).Return(&types.BroadcastAppendResult{}, nil)
mb.EXPECT().WithResourceKeys(mock.Anything, mock.Anything).Return(mba, nil)
mb.EXPECT().Ack(mock.Anything, mock.Anything).Return(nil)
mb.EXPECT().LegacyAck(mock.Anything, mock.Anything, mock.Anything).Return(nil)
mb.EXPECT().Close().Return().Maybe()
broadcast.Register(mb)
msg := message.NewCreateCollectionMessageBuilderV1().
@ -54,4 +58,35 @@ func TestBroadcastService(t *testing.T) {
Properties: map[string]string{"key": "value"},
},
})
ctx, cancel := context.WithCancel(context.Background())
reached := make(chan struct{})
done := make(chan struct{})
mb.EXPECT().Ack(mock.Anything, mock.Anything).Unset()
mb.EXPECT().Ack(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, msg message.ImmutableMessage) error {
close(reached)
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
return nil
}
})
go func() {
<-reached
cancel()
time.Sleep(10 * time.Millisecond)
close(done)
}()
_, err := service.Ack(ctx, &streamingpb.BroadcastAckRequest{
BroadcastId: 1,
Vchannel: "v1",
Message: &commonpb.ImmutableMessage{
Id: walimplstest.NewTestMessageID(1).IntoProto(),
Payload: []byte("payload"),
Properties: map[string]string{"key": "value"},
},
})
assert.NoError(t, err)
}
