From bf6405ca938705dcbaa965dbef05ecfd197b18ef Mon Sep 17 00:00:00 2001 From: bigsheeper Date: Fri, 17 Dec 2021 14:41:33 +0800 Subject: [PATCH] Simplify tSafe in query node (#13241) Signed-off-by: bigsheeper --- internal/querynode/data_sync_service.go | 15 -- internal/querynode/data_sync_service_test.go | 9 +- internal/querynode/flow_graph_query_node.go | 5 +- .../querynode/flow_graph_query_node_test.go | 4 +- .../querynode/flow_graph_service_time_node.go | 16 +- .../flow_graph_service_time_node_test.go | 7 +- internal/querynode/historical_test.go | 12 +- internal/querynode/mock_test.go | 2 +- internal/querynode/plan_test.go | 2 +- internal/querynode/query_collection_test.go | 22 ++- internal/querynode/query_node.go | 2 +- internal/querynode/query_node_test.go | 2 +- internal/querynode/query_service_test.go | 2 +- internal/querynode/segment_test.go | 4 +- internal/querynode/streaming_test.go | 16 +- internal/querynode/task.go | 65 ++++--- internal/querynode/tsafe.go | 158 +++--------------- internal/querynode/tsafe_replica.go | 88 +++------- internal/querynode/tsafe_replica_test.go | 71 ++++---- internal/querynode/tsafe_test.go | 72 +++----- 20 files changed, 165 insertions(+), 409 deletions(-) diff --git a/internal/querynode/data_sync_service.go b/internal/querynode/data_sync_service.go index e1327f5835..84946fc403 100644 --- a/internal/querynode/data_sync_service.go +++ b/internal/querynode/data_sync_service.go @@ -91,7 +91,6 @@ func (dsService *dataSyncService) addCollectionDeltaFlowGraph(collectionID Uniqu // collection flow graph doesn't need partition id partitionID := UniqueID(0) newFlowGraph := newQueryNodeDeltaFlowGraph(dsService.ctx, - loadTypeCollection, collectionID, partitionID, dsService.historicalReplica, @@ -272,20 +271,6 @@ func (dsService *dataSyncService) startPartitionFlowGraph(partitionID UniqueID, func (dsService *dataSyncService) removePartitionFlowGraph(partitionID UniqueID) { dsService.mu.Lock() defer dsService.mu.Unlock() - - if _, ok := dsService.partitionFlowGraphs[partitionID]; ok { - for channel, nodeFG := range dsService.partitionFlowGraphs[partitionID] { - // close flow graph - nodeFG.close() - // remove tSafe record - // no tSafe in tSafeReplica, don't return error - err := dsService.tSafeReplica.removeRecord(channel, partitionID) - if err != nil { - log.Warn(err.Error()) - } - } - dsService.partitionFlowGraphs[partitionID] = nil - } delete(dsService.partitionFlowGraphs, partitionID) } diff --git a/internal/querynode/data_sync_service_test.go b/internal/querynode/data_sync_service_test.go index c2e4dabddc..51e3c14d80 100644 --- a/internal/querynode/data_sync_service_test.go +++ b/internal/querynode/data_sync_service_test.go @@ -146,7 +146,7 @@ func TestDataSyncService_collectionFlowGraphs(t *testing.T) { fac, err := genFactory() assert.NoError(t, err) - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() dataSyncService := newDataSyncService(ctx, streaming, historicalReplica, tSafe, fac) assert.NotNil(t, dataSyncService) @@ -193,7 +193,7 @@ func TestDataSyncService_partitionFlowGraphs(t *testing.T) { fac, err := genFactory() assert.NoError(t, err) - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() dataSyncService := newDataSyncService(ctx, streaming, historicalReplica, tSafe, fac) assert.NotNil(t, dataSyncService) @@ -242,7 +242,7 @@ func TestDataSyncService_removePartitionFlowGraphs(t *testing.T) { fac, err := genFactory() assert.NoError(t, err) - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() tSafe.addTSafe(defaultVChannel) dataSyncService := newDataSyncService(ctx, streaming, historicalReplica, tSafe, fac) @@ -250,8 +250,7 @@ func TestDataSyncService_removePartitionFlowGraphs(t *testing.T) { dataSyncService.addPartitionFlowGraph(defaultPartitionID, defaultPartitionID, []Channel{defaultVChannel}) - isRemoved := dataSyncService.tSafeReplica.removeTSafe(defaultVChannel) - assert.True(t, isRemoved) + dataSyncService.tSafeReplica.removeTSafe(defaultVChannel) dataSyncService.removePartitionFlowGraph(defaultPartitionID) }) } diff --git a/internal/querynode/flow_graph_query_node.go b/internal/querynode/flow_graph_query_node.go index a205c46022..5dccf460c3 100644 --- a/internal/querynode/flow_graph_query_node.go +++ b/internal/querynode/flow_graph_query_node.go @@ -63,7 +63,7 @@ func newQueryNodeFlowGraph(ctx context.Context, var dmStreamNode node = q.newDmInputNode(ctx1, factory) var filterDmNode node = newFilteredDmNode(streamingReplica, loadType, collectionID, partitionID) var insertNode node = newInsertNode(streamingReplica) - var serviceTimeNode node = newServiceTimeNode(ctx1, tSafeReplica, loadType, collectionID, partitionID, channel, factory) + var serviceTimeNode node = newServiceTimeNode(ctx1, tSafeReplica, loadType, channel, factory) q.flowGraph.AddNode(dmStreamNode) q.flowGraph.AddNode(filterDmNode) @@ -110,7 +110,6 @@ func newQueryNodeFlowGraph(ctx context.Context, } func newQueryNodeDeltaFlowGraph(ctx context.Context, - loadType loadType, collectionID UniqueID, partitionID UniqueID, historicalReplica ReplicaInterface, @@ -132,7 +131,7 @@ func newQueryNodeDeltaFlowGraph(ctx context.Context, var dmStreamNode node = q.newDmInputNode(ctx1, factory) var filterDeleteNode node = newFilteredDeleteNode(historicalReplica, collectionID, partitionID) var deleteNode node = newDeleteNode(historicalReplica) - var serviceTimeNode node = newServiceTimeNode(ctx1, tSafeReplica, loadTypeCollection, collectionID, partitionID, channel, factory) + var serviceTimeNode node = newServiceTimeNode(ctx1, tSafeReplica, loadTypeCollection, channel, factory) q.flowGraph.AddNode(dmStreamNode) q.flowGraph.AddNode(filterDeleteNode) diff --git a/internal/querynode/flow_graph_query_node_test.go b/internal/querynode/flow_graph_query_node_test.go index 484482777e..f45235f5a5 100644 --- a/internal/querynode/flow_graph_query_node_test.go +++ b/internal/querynode/flow_graph_query_node_test.go @@ -29,7 +29,7 @@ func TestQueryNodeFlowGraph_consumerFlowGraph(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() streamingReplica, err := genSimpleReplica() assert.NoError(t, err) @@ -62,7 +62,7 @@ func TestQueryNodeFlowGraph_seekQueryNodeFlowGraph(t *testing.T) { fac, err := genFactory() assert.NoError(t, err) - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() fg := newQueryNodeFlowGraph(ctx, loadTypeCollection, diff --git a/internal/querynode/flow_graph_service_time_node.go b/internal/querynode/flow_graph_service_time_node.go index 7c0c5947e7..4baadfc4e7 100644 --- a/internal/querynode/flow_graph_service_time_node.go +++ b/internal/querynode/flow_graph_service_time_node.go @@ -29,8 +29,6 @@ import ( type serviceTimeNode struct { baseNode loadType loadType - collectionID UniqueID - partitionID UniqueID vChannel Channel tSafeReplica TSafeReplicaInterface timeTickMsgStream msgstream.MsgStream @@ -64,15 +62,9 @@ func (stNode *serviceTimeNode) Operate(in []flowgraph.Msg) []flowgraph.Msg { } // update service time - var id UniqueID - if stNode.loadType == loadTypePartition { - id = stNode.partitionID - } else { - id = stNode.collectionID - } - err := stNode.tSafeReplica.setTSafe(stNode.vChannel, id, serviceTimeMsg.timeRange.timestampMax) + err := stNode.tSafeReplica.setTSafe(stNode.vChannel, serviceTimeMsg.timeRange.timestampMax) if err != nil { - log.Warn(err.Error()) + log.Error(err.Error()) } //p, _ := tsoutil.ParseTS(serviceTimeMsg.timeRange.timestampMax) //log.Debug("update tSafe:", @@ -114,8 +106,6 @@ func (stNode *serviceTimeNode) Operate(in []flowgraph.Msg) []flowgraph.Msg { func newServiceTimeNode(ctx context.Context, tSafeReplica TSafeReplicaInterface, loadType loadType, - collectionID UniqueID, - partitionID UniqueID, channel Channel, factory msgstream.Factory) *serviceTimeNode { @@ -139,8 +129,6 @@ func newServiceTimeNode(ctx context.Context, return &serviceTimeNode{ baseNode: baseNode, loadType: loadType, - collectionID: collectionID, - partitionID: partitionID, vChannel: channel, tSafeReplica: tSafeReplica, timeTickMsgStream: timeTimeMsgStream, diff --git a/internal/querynode/flow_graph_service_time_node_test.go b/internal/querynode/flow_graph_service_time_node_test.go index 681fa4af5f..a55250bed4 100644 --- a/internal/querynode/flow_graph_service_time_node_test.go +++ b/internal/querynode/flow_graph_service_time_node_test.go @@ -30,7 +30,7 @@ func TestServiceTimeNode_Operate(t *testing.T) { defer cancel() genServiceTimeNode := func() *serviceTimeNode { - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() tSafe.addTSafe(defaultVChannel) fac, err := genFactory() @@ -39,8 +39,6 @@ func TestServiceTimeNode_Operate(t *testing.T) { node := newServiceTimeNode(ctx, tSafe, loadTypeCollection, - defaultCollectionID, - defaultPartitionID, defaultVChannel, fac) return node @@ -85,8 +83,7 @@ func TestServiceTimeNode_Operate(t *testing.T) { t.Run("test no tSafe", func(t *testing.T) { node := genServiceTimeNode() - isRemoved := node.tSafeReplica.removeTSafe(defaultVChannel) - assert.True(t, isRemoved) + node.tSafeReplica.removeTSafe(defaultVChannel) msg := &serviceTimeMsg{ timeRange: TimeRange{ timestampMin: 0, diff --git a/internal/querynode/historical_test.go b/internal/querynode/historical_test.go index 1896526b70..af4dc70fa0 100644 --- a/internal/querynode/historical_test.go +++ b/internal/querynode/historical_test.go @@ -100,7 +100,7 @@ func TestHistorical_Search(t *testing.T) { defer cancel() t.Run("test search", func(t *testing.T) { - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() his, err := genSimpleHistorical(ctx, tSafe) assert.NoError(t, err) @@ -112,7 +112,7 @@ func TestHistorical_Search(t *testing.T) { }) t.Run("test no collection - search partitions", func(t *testing.T) { - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() his, err := genSimpleHistorical(ctx, tSafe) assert.NoError(t, err) @@ -127,7 +127,7 @@ func TestHistorical_Search(t *testing.T) { }) t.Run("test no collection - search all collection", func(t *testing.T) { - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() his, err := genSimpleHistorical(ctx, tSafe) assert.NoError(t, err) @@ -142,7 +142,7 @@ func TestHistorical_Search(t *testing.T) { }) t.Run("test load partition and partition has been released", func(t *testing.T) { - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() his, err := genSimpleHistorical(ctx, tSafe) assert.NoError(t, err) @@ -161,7 +161,7 @@ func TestHistorical_Search(t *testing.T) { }) t.Run("test no partition in collection", func(t *testing.T) { - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() his, err := genSimpleHistorical(ctx, tSafe) assert.NoError(t, err) @@ -178,7 +178,7 @@ func TestHistorical_Search(t *testing.T) { }) t.Run("test load collection partition released in collection", func(t *testing.T) { - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() his, err := genSimpleHistorical(ctx, tSafe) assert.NoError(t, err) diff --git a/internal/querynode/mock_test.go b/internal/querynode/mock_test.go index 37b9f64dee..9f66e53f7f 100644 --- a/internal/querynode/mock_test.go +++ b/internal/querynode/mock_test.go @@ -1308,7 +1308,7 @@ func genSimpleQueryNode(ctx context.Context) (*QueryNode, error) { node.etcdKV = etcdKV - node.tSafeReplica = newTSafeReplica(ctx) + node.tSafeReplica = newTSafeReplica() streaming, err := genSimpleStreaming(ctx, node.tSafeReplica) if err != nil { diff --git a/internal/querynode/plan_test.go b/internal/querynode/plan_test.go index b3736dc027..46591c33b6 100644 --- a/internal/querynode/plan_test.go +++ b/internal/querynode/plan_test.go @@ -52,7 +52,7 @@ func TestPlan_createSearchPlanByExpr(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() historical, err := genSimpleHistorical(ctx, tSafe) assert.NoError(t, err) diff --git a/internal/querynode/query_collection_test.go b/internal/querynode/query_collection_test.go index f3d41b41eb..7fb35894c4 100644 --- a/internal/querynode/query_collection_test.go +++ b/internal/querynode/query_collection_test.go @@ -45,7 +45,7 @@ import ( ) func genSimpleQueryCollection(ctx context.Context, cancel context.CancelFunc) (*queryCollection, error) { - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() historical, err := genSimpleHistorical(ctx, tSafe) if err != nil { return nil, err @@ -110,19 +110,16 @@ func genSimpleSealedSegmentsChangeInfoMsg() *msgstream.SealedSegmentsChangeInfoM } } -func updateTSafe(queryCollection *queryCollection, timestamp Timestamp) { +func updateTSafe(queryCollection *queryCollection, timestamp Timestamp) error { // register queryCollection.tSafeWatchers[defaultVChannel] = newTSafeWatcher() queryCollection.tSafeWatchers[defaultHistoricalVChannel] = newTSafeWatcher() - queryCollection.streaming.tSafeReplica.addTSafe(defaultVChannel) - queryCollection.streaming.tSafeReplica.registerTSafeWatcher(defaultVChannel, queryCollection.tSafeWatchers[defaultVChannel]) - queryCollection.historical.tSafeReplica.addTSafe(defaultHistoricalVChannel) - queryCollection.historical.tSafeReplica.registerTSafeWatcher(defaultHistoricalVChannel, queryCollection.tSafeWatchers[defaultHistoricalVChannel]) - queryCollection.addTSafeWatcher(defaultVChannel) - queryCollection.addTSafeWatcher(defaultHistoricalVChannel) - queryCollection.streaming.tSafeReplica.setTSafe(defaultVChannel, defaultCollectionID, timestamp) - queryCollection.historical.tSafeReplica.setTSafe(defaultHistoricalVChannel, defaultCollectionID, timestamp) + err := queryCollection.streaming.tSafeReplica.setTSafe(defaultVChannel, timestamp) + if err != nil { + return err + } + return queryCollection.historical.tSafeReplica.setTSafe(defaultHistoricalVChannel, timestamp) } func TestQueryCollection_withoutVChannel(t *testing.T) { @@ -139,7 +136,7 @@ func TestQueryCollection_withoutVChannel(t *testing.T) { schema := genTestCollectionSchema(0, false, 2) historicalReplica := newCollectionReplica(etcdKV) - tsReplica := newTSafeReplica(ctx) + tsReplica := newTSafeReplica() streamingReplica := newCollectionReplica(etcdKV) historical := newHistorical(context.Background(), historicalReplica, etcdKV, tsReplica) @@ -508,7 +505,8 @@ func TestQueryCollection_waitNewTSafe(t *testing.T) { assert.NoError(t, err) timestamp := Timestamp(1000) - updateTSafe(queryCollection, timestamp) + err = updateTSafe(queryCollection, timestamp) + assert.NoError(t, err) resTimestamp, err := queryCollection.waitNewTSafe() assert.NoError(t, err) diff --git a/internal/querynode/query_node.go b/internal/querynode/query_node.go index 9c1137934d..980530cb8f 100644 --- a/internal/querynode/query_node.go +++ b/internal/querynode/query_node.go @@ -206,7 +206,7 @@ func (node *QueryNode) Init() error { zap.Any("EtcdEndpoints", Params.EtcdEndpoints), zap.Any("MetaRootPath", Params.MetaRootPath), ) - node.tSafeReplica = newTSafeReplica(node.queryNodeLoopCtx) + node.tSafeReplica = newTSafeReplica() streamingReplica := newCollectionReplica(node.etcdKV) historicalReplica := newCollectionReplica(node.etcdKV) diff --git a/internal/querynode/query_node_test.go b/internal/querynode/query_node_test.go index 920963cf82..470eb2b7cc 100644 --- a/internal/querynode/query_node_test.go +++ b/internal/querynode/query_node_test.go @@ -197,7 +197,7 @@ func newQueryNodeMock() *QueryNode { panic(err) } svr := NewQueryNode(ctx, msFactory) - tsReplica := newTSafeReplica(ctx) + tsReplica := newTSafeReplica() streamingReplica := newCollectionReplica(etcdKV) historicalReplica := newCollectionReplica(etcdKV) svr.historical = newHistorical(svr.queryNodeLoopCtx, historicalReplica, etcdKV, tsReplica) diff --git a/internal/querynode/query_service_test.go b/internal/querynode/query_service_test.go index aa80c01323..1a2a611af8 100644 --- a/internal/querynode/query_service_test.go +++ b/internal/querynode/query_service_test.go @@ -226,7 +226,7 @@ func TestQueryService_addQueryCollection(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() his, err := genSimpleHistorical(ctx, tSafe) assert.NoError(t, err) diff --git a/internal/querynode/segment_test.go b/internal/querynode/segment_test.go index 935a88c0ca..f3382ea863 100644 --- a/internal/querynode/segment_test.go +++ b/internal/querynode/segment_test.go @@ -891,7 +891,7 @@ func TestSegment_indexInfoTest(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() h, err := genSimpleHistorical(ctx, tSafe) assert.NoError(t, err) @@ -944,7 +944,7 @@ func TestSegment_indexInfoTest(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() h, err := genSimpleHistorical(ctx, tSafe) assert.NoError(t, err) diff --git a/internal/querynode/streaming_test.go b/internal/querynode/streaming_test.go index d2f35f238a..67a77e2365 100644 --- a/internal/querynode/streaming_test.go +++ b/internal/querynode/streaming_test.go @@ -27,7 +27,7 @@ func TestStreaming_streaming(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() streaming, err := genSimpleStreaming(ctx, tSafe) assert.NoError(t, err) defer streaming.close() @@ -40,7 +40,7 @@ func TestStreaming_search(t *testing.T) { defer cancel() t.Run("test search", func(t *testing.T) { - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() streaming, err := genSimpleStreaming(ctx, tSafe) assert.NoError(t, err) defer streaming.close() @@ -59,7 +59,7 @@ func TestStreaming_search(t *testing.T) { }) t.Run("test run empty partition", func(t *testing.T) { - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() streaming, err := genSimpleStreaming(ctx, tSafe) assert.NoError(t, err) defer streaming.close() @@ -78,7 +78,7 @@ func TestStreaming_search(t *testing.T) { }) t.Run("test run empty partition and loadCollection", func(t *testing.T) { - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() streaming, err := genSimpleStreaming(ctx, tSafe) assert.NoError(t, err) defer streaming.close() @@ -104,7 +104,7 @@ func TestStreaming_search(t *testing.T) { }) t.Run("test run empty partition and loadPartition", func(t *testing.T) { - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() streaming, err := genSimpleStreaming(ctx, tSafe) assert.NoError(t, err) defer streaming.close() @@ -129,7 +129,7 @@ func TestStreaming_search(t *testing.T) { }) t.Run("test no partitions in collection", func(t *testing.T) { - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() streaming, err := genSimpleStreaming(ctx, tSafe) assert.NoError(t, err) defer streaming.close() @@ -151,7 +151,7 @@ func TestStreaming_search(t *testing.T) { }) t.Run("test search failed", func(t *testing.T) { - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() streaming, err := genSimpleStreaming(ctx, tSafe) assert.NoError(t, err) defer streaming.close() @@ -178,7 +178,7 @@ func TestStreaming_retrieve(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tSafe := newTSafeReplica(ctx) + tSafe := newTSafeReplica() streaming, err := genSimpleStreaming(ctx, tSafe) assert.NoError(t, err) defer streaming.close() diff --git a/internal/querynode/task.go b/internal/querynode/task.go index 43db567588..c236b6a0d9 100644 --- a/internal/querynode/task.go +++ b/internal/querynode/task.go @@ -843,8 +843,7 @@ func (r *releaseCollectionTask) releaseReplica(replica ReplicaInterface, replica zap.Any("collectionID", r.req.CollectionID), zap.Any("vChannel", channel), ) - // no tSafe in tSafeReplica, don't return error - _ = r.node.tSafeReplica.removeTSafe(channel) + r.node.tSafeReplica.removeTSafe(channel) // queryCollection and Collection would be deleted in releaseCollection, // so we don't need to remove the tSafeWatcher or channel manually. } @@ -856,8 +855,7 @@ func (r *releaseCollectionTask) releaseReplica(replica ReplicaInterface, replica zap.Any("collectionID", r.req.CollectionID), zap.Any("vDeltaChannel", channel), ) - // no tSafe in tSafeReplica, don't return error - _ = r.node.tSafeReplica.removeTSafe(channel) + r.node.tSafeReplica.removeTSafe(channel) // queryCollection and Collection would be deleted in releaseCollection, // so we don't need to remove the tSafeWatcher or channel manually. } @@ -931,26 +929,22 @@ func (r *releasePartitionsTask) Execute(ctx context.Context) error { zap.Any("partitionID", id), zap.Any("vChannel", channel), ) - // no tSafe in tSafeReplica, don't return error - isRemoved := r.node.tSafeReplica.removeTSafe(channel) - if isRemoved { - // no tSafe or tSafe has been removed, - // we need to remove the corresponding tSafeWatcher in queryCollection, - // and remove the corresponding channel in collection - qc, err := r.node.queryService.getQueryCollection(r.req.CollectionID) - if err != nil { - return err - } - err = qc.removeTSafeWatcher(channel) - if err != nil { - return err - } - sCol.removeVChannel(channel) - hCol.removeVChannel(channel) + r.node.tSafeReplica.removeTSafe(channel) + // no tSafe or tSafe has been removed, + // we need to remove the corresponding tSafeWatcher in queryCollection, + // and remove the corresponding channel in collection + qc, err := r.node.queryService.getQueryCollection(r.req.CollectionID) + if err != nil { + return err } + err = qc.removeTSafeWatcher(channel) + if err != nil { + return err + } + sCol.removeVChannel(channel) + hCol.removeVChannel(channel) } } - // remove partition from streaming and historical hasPartitionInHistorical := r.node.historical.replica.hasPartition(id) if hasPartitionInHistorical { @@ -986,23 +980,20 @@ func (r *releasePartitionsTask) Execute(ctx context.Context) error { zap.Any("collectionID", r.req.CollectionID), zap.Any("vChannel", channel), ) - // no tSafe in tSafeReplica, don't return error - isRemoved := r.node.tSafeReplica.removeTSafe(channel) - if isRemoved { - // no tSafe or tSafe has been removed, - // we need to remove the corresponding tSafeWatcher in queryCollection, - // and remove the corresponding channel in collection - qc, err := r.node.queryService.getQueryCollection(r.req.CollectionID) - if err != nil { - return err - } - err = qc.removeTSafeWatcher(channel) - if err != nil { - return err - } - sCol.removeVDeltaChannel(channel) - hCol.removeVDeltaChannel(channel) + r.node.tSafeReplica.removeTSafe(channel) + // no tSafe or tSafe has been removed, + // we need to remove the corresponding tSafeWatcher in queryCollection, + // and remove the corresponding channel in collection + qc, err := r.node.queryService.getQueryCollection(r.req.CollectionID) + if err != nil { + return err } + err = qc.removeTSafeWatcher(channel) + if err != nil { + return err + } + sCol.removeVDeltaChannel(channel) + hCol.removeVDeltaChannel(channel) } } diff --git a/internal/querynode/tsafe.go b/internal/querynode/tsafe.go index 14d087b6cd..27a20b0b4e 100644 --- a/internal/querynode/tsafe.go +++ b/internal/querynode/tsafe.go @@ -17,14 +17,10 @@ package querynode import ( - "context" "errors" - "math" + "fmt" "sync" - "go.uber.org/zap" - - "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/util/typeutil" ) @@ -54,132 +50,27 @@ func (watcher *tSafeWatcher) close() { watcher.closeCh <- struct{}{} } -type tSafer interface { - get() Timestamp - set(id UniqueID, t Timestamp) - registerTSafeWatcher(t *tSafeWatcher) error - start() - close() - removeRecord(partitionID UniqueID) -} - -type tSafeMsg struct { - t Timestamp - id UniqueID // collectionID or partitionID -} - type tSafe struct { - ctx context.Context - cancel context.CancelFunc - channel Channel - tSafeMu sync.Mutex // guards all fields - tSafe Timestamp - watcherList []*tSafeWatcher - tSafeChan chan tSafeMsg - tSafeRecord map[UniqueID]Timestamp - isClose bool + channel Channel + tSafeMu sync.Mutex // guards all fields + tSafe Timestamp + watcher *tSafeWatcher } -func newTSafe(ctx context.Context, channel Channel) tSafer { - ctx1, cancel := context.WithCancel(ctx) - const channelSize = 4096 - - var t tSafer = &tSafe{ - ctx: ctx1, - cancel: cancel, - channel: channel, - watcherList: make([]*tSafeWatcher, 0), - tSafeChan: make(chan tSafeMsg, channelSize), - tSafeRecord: make(map[UniqueID]Timestamp), - tSafe: typeutil.ZeroTimestamp, - } - return t -} - -func (ts *tSafe) start() { - go func() { - for { - select { - case <-ts.ctx.Done(): - ts.tSafeMu.Lock() - ts.isClose = true - log.Debug("tSafe context done", - zap.Any("channel", ts.channel), - ) - for _, watcher := range ts.watcherList { - close(watcher.notifyChan) - } - ts.watcherList = nil - close(ts.tSafeChan) - ts.tSafeMu.Unlock() - return - case m, ok := <-ts.tSafeChan: - if !ok { - // should not happen!! - return - } - ts.tSafeMu.Lock() - ts.tSafeRecord[m.id] = m.t - var tmpT Timestamp = math.MaxUint64 - for _, t := range ts.tSafeRecord { - if t <= tmpT { - tmpT = t - } - } - ts.tSafe = tmpT - for _, watcher := range ts.watcherList { - watcher.notify() - } - - //log.Debug("set tSafe done", - // zap.Any("id", m.id), - // zap.Any("channel", ts.channel), - // zap.Any("t", m.t), - // zap.Any("tSafe", ts.tSafe)) - ts.tSafeMu.Unlock() - } - } - }() -} - -// removeRecord for deleting the old partition which has been released, -// if we don't delete this, tSafe would always be the old partition's timestamp -// (because we set tSafe to the minimum timestamp) from old partition -// flow graph which has been closed and would not update tSafe any more. -// removeRecord should be called when flow graph is been removed. -func (ts *tSafe) removeRecord(partitionID UniqueID) { - ts.tSafeMu.Lock() - defer ts.tSafeMu.Unlock() - if ts.isClose { - // should not happen if tsafe_replica guard correctly - log.Warn("Try to remove record with tsafe close ", - zap.Any("channel", ts.channel), - zap.Any("id", partitionID)) - return - } - log.Debug("remove tSafeRecord", - zap.Any("partitionID", partitionID), - ) - delete(ts.tSafeRecord, partitionID) - var tmpT Timestamp = math.MaxUint64 - for _, t := range ts.tSafeRecord { - if t <= tmpT { - tmpT = t - } - } - ts.tSafe = tmpT - for _, watcher := range ts.watcherList { - watcher.notify() +func newTSafe(channel Channel) *tSafe { + return &tSafe{ + channel: channel, + tSafe: typeutil.ZeroTimestamp, } } func (ts *tSafe) registerTSafeWatcher(t *tSafeWatcher) error { ts.tSafeMu.Lock() - if ts.isClose { - return errors.New("Failed to register tsafe watcher because tsafe is closed " + ts.channel) - } defer ts.tSafeMu.Unlock() - ts.watcherList = append(ts.watcherList, t) + if ts.watcher != nil { + return errors.New(fmt.Sprintln("tSafeWatcher has been existed, channel = ", ts.channel)) + } + ts.watcher = t return nil } @@ -189,23 +80,14 @@ func (ts *tSafe) get() Timestamp { return ts.tSafe } -func (ts *tSafe) set(id UniqueID, t Timestamp) { +func (ts *tSafe) set(t Timestamp) { ts.tSafeMu.Lock() defer ts.tSafeMu.Unlock() - if ts.isClose { - // should not happen if tsafe_replica guard correctly - log.Warn("Try to set id with tsafe close ", - zap.Any("channel", ts.channel), - zap.Any("id", id)) - return + ts.tSafe = t + if ts.watcher != nil { + ts.watcher.notify() } - msg := tSafeMsg{ - t: t, - id: id, - } - ts.tSafeChan <- msg -} - -func (ts *tSafe) close() { - ts.cancel() + //log.Debug("set tSafe done", + // zap.Any("channel", ts.channel), + // zap.Any("t", m.t)) } diff --git a/internal/querynode/tsafe_replica.go b/internal/querynode/tsafe_replica.go index 227d2d9815..36b6a0e17f 100644 --- a/internal/querynode/tsafe_replica.go +++ b/internal/querynode/tsafe_replica.go @@ -17,7 +17,6 @@ package querynode import ( - "context" "errors" "sync" @@ -29,29 +28,22 @@ import ( // TSafeReplicaInterface is the interface wrapper of tSafeReplica type TSafeReplicaInterface interface { getTSafe(vChannel Channel) (Timestamp, error) - setTSafe(vChannel Channel, id UniqueID, timestamp Timestamp) error + setTSafe(vChannel Channel, timestamp Timestamp) error addTSafe(vChannel Channel) - removeTSafe(vChannel Channel) bool + removeTSafe(vChannel Channel) registerTSafeWatcher(vChannel Channel, watcher *tSafeWatcher) error - removeRecord(vChannel Channel, partitionID UniqueID) error -} - -type tSafeRef struct { - tSafer tSafer - ref int } // tSafeReplica implements `TSafeReplicaInterface` interface. type tSafeReplica struct { - mu sync.Mutex // guards tSafes - tSafes map[Channel]*tSafeRef // map[vChannel]tSafeRef - ctx context.Context + mu sync.Mutex // guards tSafes + tSafes map[Channel]*tSafe // map[DMLChannel|deltaChannel]*tSafe } func (t *tSafeReplica) getTSafe(vChannel Channel) (Timestamp, error) { t.mu.Lock() defer t.mu.Unlock() - safer, err := t.getTSaferPrivate(vChannel) + ts, err := t.getTSafePrivate(vChannel) if err != nil { //log.Warn("get tSafe failed", // zap.Any("channel", vChannel), @@ -59,105 +51,67 @@ func (t *tSafeReplica) getTSafe(vChannel Channel) (Timestamp, error) { //) return 0, err } - return safer.get(), nil + return ts.get(), nil } -func (t *tSafeReplica) setTSafe(vChannel Channel, id UniqueID, timestamp Timestamp) error { +func (t *tSafeReplica) setTSafe(vChannel Channel, timestamp Timestamp) error { t.mu.Lock() defer t.mu.Unlock() - safer, err := t.getTSaferPrivate(vChannel) + ts, err := t.getTSafePrivate(vChannel) if err != nil { - //log.Warn("set tSafe failed", zap.Error(err)) - return err + return errors.New("set tSafe failed, err = " + err.Error()) } - safer.set(id, timestamp) + ts.set(timestamp) return nil } -func (t *tSafeReplica) getTSaferPrivate(vChannel Channel) (tSafer, error) { +func (t *tSafeReplica) getTSafePrivate(vChannel Channel) (*tSafe, error) { if _, ok := t.tSafes[vChannel]; !ok { err := errors.New("cannot found tSafer, vChannel = " + vChannel) //log.Warn(err.Error()) return nil, err } - return t.tSafes[vChannel].tSafer, nil + return t.tSafes[vChannel], nil } func (t *tSafeReplica) addTSafe(vChannel Channel) { t.mu.Lock() defer t.mu.Unlock() if _, ok := t.tSafes[vChannel]; !ok { - t.tSafes[vChannel] = &tSafeRef{ - tSafer: newTSafe(t.ctx, vChannel), - ref: 1, - } - t.tSafes[vChannel].tSafer.start() + t.tSafes[vChannel] = newTSafe(vChannel) log.Debug("add tSafe done", zap.Any("channel", vChannel), - zap.Any("count", t.tSafes[vChannel].ref), ) } else { - t.tSafes[vChannel].ref++ log.Debug("tSafe has been existed", zap.Any("channel", vChannel), - zap.Any("count", t.tSafes[vChannel].ref), ) } } -func (t *tSafeReplica) removeTSafe(vChannel Channel) bool { +func (t *tSafeReplica) removeTSafe(vChannel Channel) { t.mu.Lock() defer t.mu.Unlock() - if _, ok := t.tSafes[vChannel]; !ok { - return false - } - isRemoved := false - t.tSafes[vChannel].ref-- - log.Debug("reduce tSafe reference count", + + log.Debug("remove tSafe replica", zap.Any("vChannel", vChannel), - zap.Any("count", t.tSafes[vChannel].ref), ) - if t.tSafes[vChannel].ref == 0 { - safer, err := t.getTSaferPrivate(vChannel) - if err != nil { - log.Warn(err.Error()) - return false - } - log.Debug("remove tSafe replica", - zap.Any("vChannel", vChannel), - ) - safer.close() - delete(t.tSafes, vChannel) - isRemoved = true - } - return isRemoved -} - -func (t *tSafeReplica) removeRecord(vChannel Channel, partitionID UniqueID) error { - t.mu.Lock() - defer t.mu.Unlock() - safer, err := t.getTSaferPrivate(vChannel) - if err != nil { - return err - } - safer.removeRecord(partitionID) - return nil + delete(t.tSafes, vChannel) } func (t *tSafeReplica) registerTSafeWatcher(vChannel Channel, watcher *tSafeWatcher) error { t.mu.Lock() defer t.mu.Unlock() - safer, err := t.getTSaferPrivate(vChannel) + ts, err := t.getTSafePrivate(vChannel) if err != nil { return err } - return safer.registerTSafeWatcher(watcher) + return ts.registerTSafeWatcher(watcher) } -func newTSafeReplica(ctx context.Context) TSafeReplicaInterface { +func newTSafeReplica() TSafeReplicaInterface { var replica TSafeReplicaInterface = &tSafeReplica{ - tSafes: make(map[string]*tSafeRef), - ctx: ctx, + tSafes: make(map[string]*tSafe), } return replica } diff --git a/internal/querynode/tsafe_replica_test.go b/internal/querynode/tsafe_replica_test.go index 7d922c68a6..9156398f96 100644 --- a/internal/querynode/tsafe_replica_test.go +++ b/internal/querynode/tsafe_replica_test.go @@ -17,51 +17,44 @@ package querynode import ( - "context" "testing" "github.com/stretchr/testify/assert" ) -func TestTSafeReplica_valid(t *testing.T) { - replica := newTSafeReplica(context.Background()) - replica.addTSafe(defaultVChannel) +func TestTSafeReplica(t *testing.T) { + t.Run("test valid", func(t *testing.T) { + replica := newTSafeReplica() + replica.addTSafe(defaultVChannel) + watcher := newTSafeWatcher() + assert.NotNil(t, watcher) - watcher := newTSafeWatcher() - err := replica.registerTSafeWatcher(defaultVChannel, watcher) - assert.NoError(t, err) + err := replica.registerTSafeWatcher(defaultVChannel, watcher) + assert.NoError(t, err) - timestamp := Timestamp(1000) - err = replica.setTSafe(defaultVChannel, defaultCollectionID, timestamp) - assert.NoError(t, err) - <-watcher.watcherChan() - resT, err := replica.getTSafe(defaultVChannel) - assert.NoError(t, err) - assert.Equal(t, timestamp, resT) + timestamp := Timestamp(1000) + err = replica.setTSafe(defaultVChannel, timestamp) + assert.NoError(t, err) - isRemoved := replica.removeTSafe(defaultVChannel) - assert.True(t, isRemoved) -} - -func TestTSafeReplica_invalid(t *testing.T) { - replica := newTSafeReplica(context.Background()) - replica.addTSafe(defaultVChannel) - - watcher := newTSafeWatcher() - err := replica.registerTSafeWatcher(defaultVChannel, watcher) - assert.NoError(t, err) - - timestamp := Timestamp(1000) - err = replica.setTSafe(defaultVChannel, defaultCollectionID, timestamp) - assert.NoError(t, err) - <-watcher.watcherChan() - resT, err := replica.getTSafe(defaultVChannel) - assert.NoError(t, err) - assert.Equal(t, timestamp, resT) - - isRemoved := replica.removeTSafe(defaultVChannel) - assert.True(t, isRemoved) - - replica.addTSafe(defaultVChannel) - replica.addTSafe(defaultVChannel) + resT, err := replica.getTSafe(defaultVChannel) + assert.NoError(t, err) + assert.Equal(t, timestamp, resT) + + replica.removeTSafe(defaultVChannel) + _, err = replica.getTSafe(defaultVChannel) + assert.Error(t, err) + }) + + t.Run("test invalid", func(t *testing.T) { + replica := newTSafeReplica() + + err := replica.registerTSafeWatcher(defaultVChannel, nil) + assert.Error(t, err) + + _, err = replica.getTSafe(defaultVChannel) + assert.Error(t, err) + + err = replica.setTSafe(defaultVChannel, Timestamp(1000)) + assert.Error(t, err) + }) } diff --git a/internal/querynode/tsafe_test.go b/internal/querynode/tsafe_test.go index 8f2c4e7222..1d3fcd4090 100644 --- a/internal/querynode/tsafe_test.go +++ b/internal/querynode/tsafe_test.go @@ -17,79 +17,49 @@ package querynode import ( - "context" "sync" "testing" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/internal/util/typeutil" ) -func TestTSafe_GetAndSet(t *testing.T) { - tSafe := newTSafe(context.Background(), "TestTSafe-channel") - tSafe.start() +func TestTSafe_TSafeWatcher(t *testing.T) { watcher := newTSafeWatcher() defer watcher.close() - err := tSafe.registerTSafeWatcher(watcher) - assert.NoError(t, err) + assert.NotNil(t, watcher) var wg sync.WaitGroup wg.Add(1) go func() { - // wait work - <-watcher.watcherChan() - timestamp := tSafe.get() - assert.Equal(t, timestamp, Timestamp(1000)) + watcher.notify() wg.Done() }() - tSafe.set(UniqueID(1), Timestamp(1000)) wg.Wait() -} - -func TestTSafe_Remove(t *testing.T) { - tSafe := newTSafe(context.Background(), "TestTSafe-remove") - tSafe.start() - watcher := newTSafeWatcher() - defer watcher.close() - err := tSafe.registerTSafeWatcher(watcher) - assert.NoError(t, err) - - tSafe.set(UniqueID(1), Timestamp(1000)) - tSafe.set(UniqueID(2), Timestamp(1001)) + // wait notify, expect non-block here <-watcher.watcherChan() - timestamp := tSafe.get() - assert.Equal(t, timestamp, Timestamp(1000)) - - tSafe.removeRecord(UniqueID(1)) - timestamp = tSafe.get() - assert.Equal(t, timestamp, Timestamp(1001)) } -func TestTSafe_Close(t *testing.T) { - tSafe := newTSafe(context.Background(), "TestTSafe-close") - tSafe.start() +func TestTSafe_TSafe(t *testing.T) { + safe := newTSafe("TestTSafe-channel") + assert.NotNil(t, safe) + + timestamp := safe.get() + assert.Equal(t, typeutil.ZeroTimestamp, timestamp) + watcher := newTSafeWatcher() defer watcher.close() - err := tSafe.registerTSafeWatcher(watcher) + assert.NotNil(t, watcher) + + err := safe.registerTSafeWatcher(watcher) + assert.NotNil(t, safe.watcher) assert.NoError(t, err) - // test set won't panic while close - go func() { - for i := 0; i <= 100; i++ { - tSafe.set(UniqueID(i), Timestamp(1000)) - } - }() + targetTimestamp := Timestamp(1000) + safe.set(targetTimestamp) - tSafe.close() - - // wait until channel close - for range watcher.watcherChan() { - - } - - tSafe.set(UniqueID(101), Timestamp(1000)) - tSafe.removeRecord(UniqueID(1)) - // register TSafe will fail - err = tSafe.registerTSafeWatcher(watcher) - assert.Error(t, err) + timestamp = safe.get() + assert.Equal(t, targetTimestamp, timestamp) }