diff --git a/internal/util/flowgraph/flow_graph.go b/internal/util/flowgraph/flow_graph.go index f2ad3cd580..6b67e0a32d 100644 --- a/internal/util/flowgraph/flow_graph.go +++ b/internal/util/flowgraph/flow_graph.go @@ -30,6 +30,7 @@ import ( // TimeTickedFlowGraph flowgraph with input from tt msg stream type TimeTickedFlowGraph struct { nodeCtx map[NodeName]*nodeCtx + nodeSequence []NodeName nodeCtxManager *nodeCtxManager stopOnce sync.Once startOnce sync.Once @@ -46,6 +47,7 @@ func (fg *TimeTickedFlowGraph) AddNode(node Node) { if node.IsInputNode() { fg.nodeCtxManager = NewNodeCtxManager(&nodeCtx, fg.closeWg) } + fg.nodeSequence = append(fg.nodeSequence, node.Name()) } // SetEdges set directed edges from in nodes to out nodes @@ -88,14 +90,16 @@ func (fg *TimeTickedFlowGraph) Start() { } func (fg *TimeTickedFlowGraph) Blockall() { - for _, v := range fg.nodeCtx { - v.Block() + // Lock with determined order to avoid deadlock. + for _, nodeName := range fg.nodeSequence { + fg.nodeCtx[nodeName].Block() } } func (fg *TimeTickedFlowGraph) Unblock() { - for _, v := range fg.nodeCtx { - v.Unblock() + // Unlock with reverse order. + for i := len(fg.nodeSequence) - 1; i >= 0; i-- { + fg.nodeCtx[fg.nodeSequence[i]].Unblock() } } diff --git a/internal/util/flowgraph/flow_graph_test.go b/internal/util/flowgraph/flow_graph_test.go index a745b72597..cb867fbf0f 100644 --- a/internal/util/flowgraph/flow_graph_test.go +++ b/internal/util/flowgraph/flow_graph_test.go @@ -192,8 +192,12 @@ func TestTimeTickedFlowGraph_AddNode(t *testing.T) { fg.AddNode(a) assert.Equal(t, len(fg.nodeCtx), 1) + assert.Equal(t, len(fg.nodeSequence), 1) + assert.Equal(t, a.Name(), fg.nodeSequence[0]) fg.AddNode(b) assert.Equal(t, len(fg.nodeCtx), 2) + assert.Equal(t, len(fg.nodeSequence), 2) + assert.Equal(t, b.Name(), fg.nodeSequence[1]) } func TestTimeTickedFlowGraph_Start(t *testing.T) { @@ -223,3 +227,30 @@ func TestTimeTickedFlowGraph_Close(t *testing.T) { defer cancel() fg.Close() } + +func TestBlockAll(t *testing.T) { + fg := NewTimeTickedFlowGraph(context.Background()) + fg.AddNode(&nodeA{}) + fg.AddNode(&nodeB{}) + fg.AddNode(&nodeC{}) + + count := 1000 + ch := make([]chan struct{}, count) + for i := 0; i < count; i++ { + ch[i] = make(chan struct{}) + go func(i int) { + fg.Blockall() + defer fg.Unblock() + close(ch[i]) + }(i) + } + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + for i := 0; i < count; i++ { + select { + case <-ch[i]: + case <-ctx.Done(): + t.Error("block all timeout") + } + } +}