From 07cc449fbf2aa6317128d18401b6e496143bc861 Mon Sep 17 00:00:00 2001 From: bigsheeper Date: Tue, 3 Aug 2021 22:43:25 +0800 Subject: [PATCH] Fix data race in flow graph (#6946) * Fix data race in flow graph Signed-off-by: bigsheeper * add cancel func to flowgraph Signed-off-by: bigsheeper --- internal/util/flowgraph/flow_graph.go | 6 +++++- internal/util/flowgraph/input_node.go | 2 +- internal/util/flowgraph/node.go | 25 +++++++++++++++++++------ 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/internal/util/flowgraph/flow_graph.go b/internal/util/flowgraph/flow_graph.go index 84a0c42c19..ef697fc021 100644 --- a/internal/util/flowgraph/flow_graph.go +++ b/internal/util/flowgraph/flow_graph.go @@ -20,6 +20,7 @@ import ( type TimeTickedFlowGraph struct { ctx context.Context + cancel context.CancelFunc nodeCtx map[NodeName]*nodeCtx } @@ -89,11 +90,14 @@ func (fg *TimeTickedFlowGraph) Close() { // } v.Close() } + fg.cancel() } func NewTimeTickedFlowGraph(ctx context.Context) *TimeTickedFlowGraph { + ctx1, cancel := context.WithCancel(ctx) flowGraph := TimeTickedFlowGraph{ - ctx: ctx, + ctx: ctx1, + cancel: cancel, nodeCtx: make(map[string]*nodeCtx), } diff --git a/internal/util/flowgraph/input_node.go b/internal/util/flowgraph/input_node.go index 9831a5e6e0..a547bf5b36 100644 --- a/internal/util/flowgraph/input_node.go +++ b/internal/util/flowgraph/input_node.go @@ -29,7 +29,7 @@ func (inNode *InputNode) IsInputNode() bool { } func (inNode *InputNode) Close() { - (*inNode.inStream).Close() + // do nothing } func (inNode *InputNode) Name() string { diff --git a/internal/util/flowgraph/node.go b/internal/util/flowgraph/node.go index 07f18d2cce..7d0725f51e 100644 --- a/internal/util/flowgraph/node.go +++ b/internal/util/flowgraph/node.go @@ -14,9 +14,12 @@ package flowgraph import ( "context" "fmt" - "log" "sync" "time" + + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/log" ) type Node interface { @@ -49,7 +52,7 @@ func (nodeCtx *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) { // fmt.Println("start InputNode.inStream") inStream, ok := nodeCtx.node.(*InputNode) if !ok { - log.Fatal("Invalid inputNode") + log.Error("Invalid inputNode") } (*inStream.inStream).Start() } @@ -57,6 +60,16 @@ func (nodeCtx *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) { for { select { case <-ctx.Done(): + if nodeCtx.node.IsInputNode() { + inStream, ok := nodeCtx.node.(*InputNode) + if !ok { + log.Error("Invalid inputNode") + } + (*inStream.inStream).Close() + log.Debug("message stream closed", + zap.Any("node name", inStream.name), + ) + } wg.Done() //fmt.Println(nodeCtx.node.Name(), "closed") return @@ -74,7 +87,7 @@ func (nodeCtx *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) { downstreamLength := len(nodeCtx.downstreamInputChanIdx) if len(nodeCtx.downstream) < downstreamLength { - log.Println("nodeCtx.downstream length = ", len(nodeCtx.downstream)) + log.Warn("", zap.Any("nodeCtx.downstream length", len(nodeCtx.downstream))) } if len(res) < downstreamLength { // log.Println("node result length = ", len(res)) @@ -104,7 +117,7 @@ func (nodeCtx *nodeCtx) ReceiveMsg(wg *sync.WaitGroup, msg Msg, inputChanIdx int defer func() { err := recover() if err != nil { - log.Println(err) + log.Warn(fmt.Sprintln(err)) } }() nodeCtx.inputChannels[inputChanIdx] <- msg @@ -126,7 +139,7 @@ func (nodeCtx *nodeCtx) collectInputMessages(exitCtx context.Context) { case msg, ok := <-channel: if !ok { // TODO: add status - log.Println("input channel closed") + log.Warn("input channel closed") return } nodeCtx.inputMessages[i] = msg @@ -155,7 +168,7 @@ func (nodeCtx *nodeCtx) collectInputMessages(exitCtx context.Context) { return case msg, ok := <-channel: if !ok { - log.Println("input channel closed") + log.Warn("input channel closed") return } nodeCtx.inputMessages[i] = msg