From de13865769ab2276c2520c91ed3a95330c01974e Mon Sep 17 00:00:00 2001 From: SimFG Date: Thu, 23 Nov 2023 15:38:24 +0800 Subject: [PATCH] enhance: Add load/release partitions to replicate msg stream (#28399) /kind improvement issue: #25655 Signed-off-by: SimFG --- internal/proxy/impl.go | 2 + internal/proxy/task.go | 21 +++- internal/proxy/util.go | 10 ++ internal/proxy/util_test.go | 2 + pkg/mq/msgstream/msg_for_collection.go | 4 +- pkg/mq/msgstream/msg_for_collection_test.go | 6 + pkg/mq/msgstream/msg_for_database_test.go | 4 +- pkg/mq/msgstream/msg_for_index_test.go | 8 +- pkg/mq/msgstream/msg_for_partition.go | 132 ++++++++++++++++++++ pkg/mq/msgstream/msg_for_partition_test.go | 118 +++++++++++++++++ pkg/mq/msgstream/unmarshal.go | 4 + 11 files changed, 301 insertions(+), 10 deletions(-) create mode 100644 pkg/mq/msgstream/msg_for_partition.go create mode 100644 pkg/mq/msgstream/msg_for_partition_test.go diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index 69a1006ba7..eb7a526be2 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -1286,6 +1286,7 @@ func (node *Proxy) LoadPartitions(ctx context.Context, request *milvuspb.LoadPar LoadPartitionsRequest: request, queryCoord: node.queryCoord, datacoord: node.dataCoord, + replicateMsgStream: node.replicateMsgStream, } log := log.Ctx(ctx).With( @@ -1351,6 +1352,7 @@ func (node *Proxy) ReleasePartitions(ctx context.Context, request *milvuspb.Rele Condition: NewTaskCondition(ctx), ReleasePartitionsRequest: request, queryCoord: node.queryCoord, + replicateMsgStream: node.replicateMsgStream, } method := "ReleasePartitions" diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 54934cd890..a1cfed1b8b 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -1590,7 +1590,7 @@ func (t *releaseCollectionTask) Execute(ctx context.Context) (err error) { return err } SendReplicateMessagePack(ctx, t.replicateMsgStream, t.ReleaseCollectionRequest) - return err + return nil } func (t *releaseCollectionTask) PostExecute(ctx context.Context) error { @@ -1606,7 +1606,8 @@ type loadPartitionsTask struct { datacoord types.DataCoordClient result *commonpb.Status - collectionID UniqueID + collectionID UniqueID + replicateMsgStream msgstream.MsgStream } func (t *loadPartitionsTask) TraceCtx() context.Context { @@ -1735,7 +1736,12 @@ func (t *loadPartitionsTask) Execute(ctx context.Context) error { ResourceGroups: t.ResourceGroups, } t.result, err = t.queryCoord.LoadPartitions(ctx, request) - return err + if err != nil { + return err + } + SendReplicateMessagePack(ctx, t.replicateMsgStream, t.LoadPartitionsRequest) + + return nil } func (t *loadPartitionsTask) PostExecute(ctx context.Context) error { @@ -1749,7 +1755,8 @@ type releasePartitionsTask struct { queryCoord types.QueryCoordClient result *commonpb.Status - collectionID UniqueID + collectionID UniqueID + replicateMsgStream msgstream.MsgStream } func (t *releasePartitionsTask) TraceCtx() context.Context { @@ -1836,7 +1843,11 @@ func (t *releasePartitionsTask) Execute(ctx context.Context) (err error) { PartitionIDs: partitionIDs, } t.result, err = t.queryCoord.ReleasePartitions(ctx, request) - return err + if err != nil { + return err + } + SendReplicateMessagePack(ctx, t.replicateMsgStream, t.ReleasePartitionsRequest) + return nil } func (t *releasePartitionsTask) PostExecute(ctx context.Context) error { diff --git a/internal/proxy/util.go b/internal/proxy/util.go index bde14d0034..f9d4a8cff9 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -1561,6 +1561,16 @@ func SendReplicateMessagePack(ctx context.Context, replicateMsgStream msgstream. BaseMsg: getBaseMsg(ctx, ts), DropIndexRequest: *r, } + case *milvuspb.LoadPartitionsRequest: + tsMsg = &msgstream.LoadPartitionsMsg{ + BaseMsg: getBaseMsg(ctx, ts), + LoadPartitionsRequest: *r, + } + case *milvuspb.ReleasePartitionsRequest: + tsMsg = &msgstream.ReleasePartitionsMsg{ + BaseMsg: getBaseMsg(ctx, ts), + ReleasePartitionsRequest: *r, + } default: log.Warn("unknown request", zap.Any("request", request)) return diff --git a/internal/proxy/util_test.go b/internal/proxy/util_test.go index 1bbc147f4d..94e15f8109 100644 --- a/internal/proxy/util_test.go +++ b/internal/proxy/util_test.go @@ -2085,5 +2085,7 @@ func TestSendReplicateMessagePack(t *testing.T) { SendReplicateMessagePack(ctx, mockStream, &milvuspb.ReleaseCollectionRequest{}) SendReplicateMessagePack(ctx, mockStream, &milvuspb.CreateIndexRequest{}) SendReplicateMessagePack(ctx, mockStream, &milvuspb.DropIndexRequest{}) + SendReplicateMessagePack(ctx, mockStream, &milvuspb.LoadPartitionsRequest{}) + SendReplicateMessagePack(ctx, mockStream, &milvuspb.ReleasePartitionsRequest{}) }) } diff --git a/pkg/mq/msgstream/msg_for_collection.go b/pkg/mq/msgstream/msg_for_collection.go index 4411684cce..a0fc13fd29 100644 --- a/pkg/mq/msgstream/msg_for_collection.go +++ b/pkg/mq/msgstream/msg_for_collection.go @@ -51,8 +51,8 @@ func (l *LoadCollectionMsg) SourceID() int64 { func (l *LoadCollectionMsg) Marshal(input TsMsg) (MarshalType, error) { loadCollectionMsg := input.(*LoadCollectionMsg) - createIndexRequest := &loadCollectionMsg.LoadCollectionRequest - mb, err := proto.Marshal(createIndexRequest) + loadCollectionRequest := &loadCollectionMsg.LoadCollectionRequest + mb, err := proto.Marshal(loadCollectionRequest) if err != nil { return nil, err } diff --git a/pkg/mq/msgstream/msg_for_collection_test.go b/pkg/mq/msgstream/msg_for_collection_test.go index 5f9f42a748..f84f17fefd 100644 --- a/pkg/mq/msgstream/msg_for_collection_test.go +++ b/pkg/mq/msgstream/msg_for_collection_test.go @@ -60,6 +60,8 @@ func TestFlushMsg(t *testing.T) { assert.EqualValues(t, 200, newMsg.ID()) assert.EqualValues(t, 1000, newMsg.BeginTs()) assert.EqualValues(t, 1000, newMsg.EndTs()) + assert.EqualValues(t, 2, len(newMsg.(*FlushMsg).CollectionNames)) + assert.EqualValues(t, "unit_db", newMsg.(*FlushMsg).DbName) assert.True(t, msg.Size() > 0) } @@ -97,6 +99,8 @@ func TestLoadCollection(t *testing.T) { assert.EqualValues(t, 200, newMsg.ID()) assert.EqualValues(t, 1000, newMsg.BeginTs()) assert.EqualValues(t, 1000, newMsg.EndTs()) + assert.EqualValues(t, "unit_db", newMsg.(*LoadCollectionMsg).DbName) + assert.EqualValues(t, "col1", newMsg.(*LoadCollectionMsg).CollectionName) assert.True(t, msg.Size() > 0) } @@ -134,6 +138,8 @@ func TestReleaseCollection(t *testing.T) { assert.EqualValues(t, 200, newMsg.ID()) assert.EqualValues(t, 1000, newMsg.BeginTs()) assert.EqualValues(t, 1000, newMsg.EndTs()) + assert.EqualValues(t, "unit_db", newMsg.(*ReleaseCollectionMsg).DbName) + assert.EqualValues(t, "col1", newMsg.(*ReleaseCollectionMsg).CollectionName) assert.True(t, msg.Size() > 0) } diff --git a/pkg/mq/msgstream/msg_for_database_test.go b/pkg/mq/msgstream/msg_for_database_test.go index d7cfc80eeb..e3d9579599 100644 --- a/pkg/mq/msgstream/msg_for_database_test.go +++ b/pkg/mq/msgstream/msg_for_database_test.go @@ -50,7 +50,7 @@ func TestCreateDatabase(t *testing.T) { msgBytes, err := msg.Marshal(msg) assert.NoError(t, err) - var newMsg TsMsg = &ReleaseCollectionMsg{} + var newMsg TsMsg = &CreateDatabaseMsg{} _, err = newMsg.Unmarshal("1") assert.Error(t, err) @@ -59,6 +59,7 @@ func TestCreateDatabase(t *testing.T) { assert.EqualValues(t, 200, newMsg.ID()) assert.EqualValues(t, 1000, newMsg.BeginTs()) assert.EqualValues(t, 1000, newMsg.EndTs()) + assert.EqualValues(t, "unit_db", newMsg.(*CreateDatabaseMsg).DbName) assert.True(t, msg.Size() > 0) } @@ -95,6 +96,7 @@ func TestDropDatabase(t *testing.T) { assert.EqualValues(t, 200, newMsg.ID()) assert.EqualValues(t, 1000, newMsg.BeginTs()) assert.EqualValues(t, 1000, newMsg.EndTs()) + assert.EqualValues(t, "unit_db", newMsg.(*DropDatabaseMsg).DbName) assert.True(t, msg.Size() > 0) } diff --git a/pkg/mq/msgstream/msg_for_index_test.go b/pkg/mq/msgstream/msg_for_index_test.go index cccc1e09b9..590231c163 100644 --- a/pkg/mq/msgstream/msg_for_index_test.go +++ b/pkg/mq/msgstream/msg_for_index_test.go @@ -74,7 +74,9 @@ func TestDropIndex(t *testing.T) { TargetID: 100000, ReplicateInfo: nil, }, - DbName: "unit_db", + DbName: "unit_db", + CollectionName: "col1", + IndexName: "unit_index", }, } assert.EqualValues(t, 100, msg.ID()) @@ -86,7 +88,7 @@ func TestDropIndex(t *testing.T) { msgBytes, err := msg.Marshal(msg) assert.NoError(t, err) - var newMsg TsMsg = &ReleaseCollectionMsg{} + var newMsg TsMsg = &DropIndexMsg{} _, err = newMsg.Unmarshal("1") assert.Error(t, err) @@ -95,6 +97,8 @@ func TestDropIndex(t *testing.T) { assert.EqualValues(t, 200, newMsg.ID()) assert.EqualValues(t, 1000, newMsg.BeginTs()) assert.EqualValues(t, 1000, newMsg.EndTs()) + assert.EqualValues(t, "col1", newMsg.(*DropIndexMsg).CollectionName) + assert.EqualValues(t, "unit_index", newMsg.(*DropIndexMsg).IndexName) assert.True(t, msg.Size() > 0) } diff --git a/pkg/mq/msgstream/msg_for_partition.go b/pkg/mq/msgstream/msg_for_partition.go new file mode 100644 index 0000000000..6a3117fa5e --- /dev/null +++ b/pkg/mq/msgstream/msg_for_partition.go @@ -0,0 +1,132 @@ +/* + * Licensed to the LF AI & Data foundation under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package msgstream + +import ( + "github.com/golang/protobuf/proto" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" +) + +type LoadPartitionsMsg struct { + BaseMsg + milvuspb.LoadPartitionsRequest +} + +var _ TsMsg = &LoadPartitionsMsg{} + +func (l *LoadPartitionsMsg) ID() UniqueID { + return l.Base.MsgID +} + +func (l *LoadPartitionsMsg) SetID(id UniqueID) { + l.Base.MsgID = id +} + +func (l *LoadPartitionsMsg) Type() MsgType { + return l.Base.MsgType +} + +func (l *LoadPartitionsMsg) SourceID() int64 { + return l.Base.SourceID +} + +func (l *LoadPartitionsMsg) Marshal(input TsMsg) (MarshalType, error) { + loadPartitionsMsg := input.(*LoadPartitionsMsg) + loadPartitionsRequest := &loadPartitionsMsg.LoadPartitionsRequest + mb, err := proto.Marshal(loadPartitionsRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (l *LoadPartitionsMsg) Unmarshal(input MarshalType) (TsMsg, error) { + loadPartitionsRequest := milvuspb.LoadPartitionsRequest{} + in, err := convertToByteArray(input) + if err != nil { + return nil, err + } + err = proto.Unmarshal(in, &loadPartitionsRequest) + if err != nil { + return nil, err + } + loadPartitionsMsg := &LoadPartitionsMsg{LoadPartitionsRequest: loadPartitionsRequest} + loadPartitionsMsg.BeginTimestamp = loadPartitionsMsg.GetBase().GetTimestamp() + loadPartitionsMsg.EndTimestamp = loadPartitionsMsg.GetBase().GetTimestamp() + + return loadPartitionsMsg, nil +} + +func (l *LoadPartitionsMsg) Size() int { + return proto.Size(&l.LoadPartitionsRequest) +} + +type ReleasePartitionsMsg struct { + BaseMsg + milvuspb.ReleasePartitionsRequest +} + +var _ TsMsg = &ReleasePartitionsMsg{} + +func (r *ReleasePartitionsMsg) ID() UniqueID { + return r.Base.MsgID +} + +func (r *ReleasePartitionsMsg) SetID(id UniqueID) { + r.Base.MsgID = id +} + +func (r *ReleasePartitionsMsg) Type() MsgType { + return r.Base.MsgType +} + +func (r *ReleasePartitionsMsg) SourceID() int64 { + return r.Base.SourceID +} + +func (r *ReleasePartitionsMsg) Marshal(input TsMsg) (MarshalType, error) { + releasePartitionsMsg := input.(*ReleasePartitionsMsg) + releasePartitionsRequest := &releasePartitionsMsg.ReleasePartitionsRequest + mb, err := proto.Marshal(releasePartitionsRequest) + if err != nil { + return nil, err + } + return mb, nil +} + +func (r *ReleasePartitionsMsg) Unmarshal(input MarshalType) (TsMsg, error) { + releasePartitionsRequest := milvuspb.ReleasePartitionsRequest{} + in, err := convertToByteArray(input) + if err != nil { + return nil, err + } + err = proto.Unmarshal(in, &releasePartitionsRequest) + if err != nil { + return nil, err + } + releasePartitionsMsg := &ReleasePartitionsMsg{ReleasePartitionsRequest: releasePartitionsRequest} + releasePartitionsMsg.BeginTimestamp = releasePartitionsMsg.GetBase().GetTimestamp() + releasePartitionsMsg.EndTimestamp = releasePartitionsMsg.GetBase().GetTimestamp() + return releasePartitionsMsg, nil +} + +func (r *ReleasePartitionsMsg) Size() int { + return proto.Size(&r.ReleasePartitionsRequest) +} diff --git a/pkg/mq/msgstream/msg_for_partition_test.go b/pkg/mq/msgstream/msg_for_partition_test.go new file mode 100644 index 0000000000..981be41bc6 --- /dev/null +++ b/pkg/mq/msgstream/msg_for_partition_test.go @@ -0,0 +1,118 @@ +/* + * Licensed to the LF AI & Data foundation under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package msgstream + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" +) + +func TestLoadPartitions(t *testing.T) { + msg := &LoadPartitionsMsg{ + LoadPartitionsRequest: milvuspb.LoadPartitionsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_LoadPartitions, + MsgID: 100, + Timestamp: 1000, + SourceID: 10000, + TargetID: 100000, + ReplicateInfo: nil, + }, + DbName: "unit_db", + CollectionName: "col1", + PartitionNames: []string{ + "p1", + "p2", + }, + }, + } + + assert.EqualValues(t, 100, msg.ID()) + msg.SetID(200) + assert.EqualValues(t, 200, msg.ID()) + assert.Equal(t, commonpb.MsgType_LoadPartitions, msg.Type()) + assert.EqualValues(t, 10000, msg.SourceID()) + + msgBytes, err := msg.Marshal(msg) + assert.NoError(t, err) + + var newMsg TsMsg = &LoadPartitionsMsg{} + _, err = newMsg.Unmarshal("1") + assert.Error(t, err) + + newMsg, err = newMsg.Unmarshal(msgBytes) + assert.NoError(t, err) + assert.EqualValues(t, 200, newMsg.ID()) + assert.EqualValues(t, 1000, newMsg.BeginTs()) + assert.EqualValues(t, 1000, newMsg.EndTs()) + assert.EqualValues(t, 2, len(newMsg.(*LoadPartitionsMsg).PartitionNames)) + assert.EqualValues(t, "unit_db", newMsg.(*LoadPartitionsMsg).DbName) + assert.EqualValues(t, "col1", newMsg.(*LoadPartitionsMsg).CollectionName) + + assert.True(t, msg.Size() > 0) +} + +func TestReleasePartitions(t *testing.T) { + msg := &ReleasePartitionsMsg{ + ReleasePartitionsRequest: milvuspb.ReleasePartitionsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_ReleasePartitions, + MsgID: 100, + Timestamp: 1000, + SourceID: 10000, + TargetID: 100000, + ReplicateInfo: nil, + }, + DbName: "unit_db", + CollectionName: "col1", + PartitionNames: []string{ + "p1", + "p2", + }, + }, + } + + assert.EqualValues(t, 100, msg.ID()) + msg.SetID(200) + assert.EqualValues(t, 200, msg.ID()) + assert.Equal(t, commonpb.MsgType_ReleasePartitions, msg.Type()) + assert.EqualValues(t, 10000, msg.SourceID()) + + msgBytes, err := msg.Marshal(msg) + assert.NoError(t, err) + + var newMsg TsMsg = &ReleasePartitionsMsg{} + _, err = newMsg.Unmarshal("1") + assert.Error(t, err) + + newMsg, err = newMsg.Unmarshal(msgBytes) + assert.NoError(t, err) + assert.EqualValues(t, 200, newMsg.ID()) + assert.EqualValues(t, 1000, newMsg.BeginTs()) + assert.EqualValues(t, 1000, newMsg.EndTs()) + assert.EqualValues(t, 2, len(newMsg.(*ReleasePartitionsMsg).PartitionNames)) + assert.EqualValues(t, "unit_db", newMsg.(*ReleasePartitionsMsg).DbName) + assert.EqualValues(t, "col1", newMsg.(*ReleasePartitionsMsg).CollectionName) + + assert.True(t, msg.Size() > 0) +} diff --git a/pkg/mq/msgstream/unmarshal.go b/pkg/mq/msgstream/unmarshal.go index 31cee49d8b..80d1e0da0a 100644 --- a/pkg/mq/msgstream/unmarshal.go +++ b/pkg/mq/msgstream/unmarshal.go @@ -69,6 +69,8 @@ func (pudf *ProtoUDFactory) NewUnmarshalDispatcher() *ProtoUnmarshalDispatcher { loadCollectionMsg := LoadCollectionMsg{} releaseCollectionMsg := ReleaseCollectionMsg{} flushMsg := FlushMsg{} + loadPartitionsMsg := LoadPartitionsMsg{} + releasePartitionsMsg := ReleasePartitionsMsg{} createDatabaseMsg := CreateDatabaseMsg{} dropDatabaseMsg := DropDatabaseMsg{} @@ -87,6 +89,8 @@ func (pudf *ProtoUDFactory) NewUnmarshalDispatcher() *ProtoUnmarshalDispatcher { p.TempMap[commonpb.MsgType_DropIndex] = dropIndexMsg.Unmarshal p.TempMap[commonpb.MsgType_LoadCollection] = loadCollectionMsg.Unmarshal p.TempMap[commonpb.MsgType_ReleaseCollection] = releaseCollectionMsg.Unmarshal + p.TempMap[commonpb.MsgType_LoadPartitions] = loadPartitionsMsg.Unmarshal + p.TempMap[commonpb.MsgType_ReleasePartitions] = releasePartitionsMsg.Unmarshal p.TempMap[commonpb.MsgType_Flush] = flushMsg.Unmarshal p.TempMap[commonpb.MsgType_CreateDatabase] = createDatabaseMsg.Unmarshal p.TempMap[commonpb.MsgType_DropDatabase] = dropDatabaseMsg.Unmarshal