diff --git a/internal/util/flowgraph/flow_graph.go b/internal/util/flowgraph/flow_graph.go index 7651d53a2e..373aab4650 100644 --- a/internal/util/flowgraph/flow_graph.go +++ b/internal/util/flowgraph/flow_graph.go @@ -89,13 +89,6 @@ func (fg *TimeTickedFlowGraph) Close() { fg.stopOnce.Do(func() { for _, v := range fg.nodeCtx { if v.node.IsInputNode() { - // close inputNode first - v.Close() - } - } - for _, v := range fg.nodeCtx { - if !v.node.IsInputNode() { - // close other nodes v.Close() } } diff --git a/internal/util/flowgraph/input_node.go b/internal/util/flowgraph/input_node.go index 861e163d4d..0ad23da735 100644 --- a/internal/util/flowgraph/input_node.go +++ b/internal/util/flowgraph/input_node.go @@ -28,8 +28,9 @@ import ( // InputNode is the entry point of flowgragh type InputNode struct { BaseNode - inStream msgstream.MsgStream - name string + inStream msgstream.MsgStream + name string + closeMsgChan chan struct{} } // IsInputNode returns whether Node is InputNode @@ -44,10 +45,15 @@ func (inNode *InputNode) Start() { // Close implements node func (inNode *InputNode) Close() { - inNode.inStream.Close() - log.Debug("message stream closed", - zap.String("node name", inNode.name), - ) + select { + case <-inNode.closeMsgChan: + return + default: + close(inNode.closeMsgChan) + log.Debug("message stream closed", + zap.String("node name", inNode.name), + ) + } } // Name returns node name @@ -62,37 +68,44 @@ func (inNode *InputNode) InStream() msgstream.MsgStream { // Operate consume a message pack from msgstream and return func (inNode *InputNode) Operate(in []Msg) []Msg { - msgPack, ok := <-inNode.inStream.Chan() - if !ok { - log.Warn("MsgStream closed", zap.Any("input node", inNode.Name())) - return []Msg{} - } + select { + case <-inNode.closeMsgChan: + inNode.inStream.Close() + return []Msg{&MsgStreamMsg{ + isCloseMsg: true, + }} + case msgPack, ok := <-inNode.inStream.Chan(): + if !ok { + log.Warn("MsgStream closed", zap.Any("input node", inNode.Name())) + return []Msg{} + } - // TODO: add status - if msgPack == nil { - return nil - } - var spans []opentracing.Span - for _, msg := range msgPack.Msgs { - sp, ctx := trace.StartSpanFromContext(msg.TraceCtx()) - sp.LogFields(oplog.String("input_node name", inNode.Name())) - spans = append(spans, sp) - msg.SetTraceCtx(ctx) - } + // TODO: add status + if msgPack == nil { + return nil + } + var spans []opentracing.Span + for _, msg := range msgPack.Msgs { + sp, ctx := trace.StartSpanFromContext(msg.TraceCtx()) + sp.LogFields(oplog.String("input_node name", inNode.Name())) + spans = append(spans, sp) + msg.SetTraceCtx(ctx) + } - var msgStreamMsg Msg = &MsgStreamMsg{ - tsMessages: msgPack.Msgs, - timestampMin: msgPack.BeginTs, - timestampMax: msgPack.EndTs, - startPositions: msgPack.StartPositions, - endPositions: msgPack.EndPositions, - } + var msgStreamMsg Msg = &MsgStreamMsg{ + tsMessages: msgPack.Msgs, + timestampMin: msgPack.BeginTs, + timestampMax: msgPack.EndTs, + startPositions: msgPack.StartPositions, + endPositions: msgPack.EndPositions, + } - for _, span := range spans { - span.Finish() - } + for _, span := range spans { + span.Finish() + } - return []Msg{msgStreamMsg} + return []Msg{msgStreamMsg} + } } // NewInputNode composes an InputNode with provided MsgStream, name and parameters @@ -102,8 +115,9 @@ func NewInputNode(inStream msgstream.MsgStream, nodeName string, maxQueueLength baseNode.SetMaxParallelism(maxParallelism) return &InputNode{ - BaseNode: baseNode, - inStream: inStream, - name: nodeName, + BaseNode: baseNode, + inStream: inStream, + name: nodeName, + closeMsgChan: make(chan struct{}), } } diff --git a/internal/util/flowgraph/input_node_test.go b/internal/util/flowgraph/input_node_test.go index 4e289965cb..7812805a4d 100644 --- a/internal/util/flowgraph/input_node_test.go +++ b/internal/util/flowgraph/input_node_test.go @@ -41,10 +41,7 @@ func TestInputNode(t *testing.T) { produceStream.Produce(&msgPack) nodeName := "input_node" - inputNode := &InputNode{ - inStream: msgStream, - name: nodeName, - } + inputNode := NewInputNode(msgStream, nodeName, 100, 100) defer inputNode.Close() isInputNode := inputNode.IsInputNode() diff --git a/internal/util/flowgraph/message.go b/internal/util/flowgraph/message.go index 83e72416ef..f2e8bcd563 100644 --- a/internal/util/flowgraph/message.go +++ b/internal/util/flowgraph/message.go @@ -32,6 +32,7 @@ type MsgStreamMsg struct { timestampMax Timestamp startPositions []*MsgPosition endPositions []*MsgPosition + isCloseMsg bool } // GenerateMsgStreamMsg is used to create a new MsgStreamMsg object diff --git a/internal/util/flowgraph/node.go b/internal/util/flowgraph/node.go index da95c3d00a..29cb8a9a3e 100644 --- a/internal/util/flowgraph/node.go +++ b/internal/util/flowgraph/node.go @@ -58,24 +58,29 @@ type nodeCtx struct { downstream []*nodeCtx downstreamInputChanIdx map[string]int - closeCh chan struct{} // notify work to exit - closeWg sync.WaitGroup // block Close until work exit + closeCh chan struct{} // notify work to exit } // Start invoke Node `Start` method and start a worker goroutine func (nodeCtx *nodeCtx) Start() { nodeCtx.node.Start() - nodeCtx.closeWg.Add(1) go nodeCtx.work() } +func isCloseMsg(msgs []Msg) bool { + if len(msgs) == 1 { + msg, ok := msgs[0].(*MsgStreamMsg) + return ok && msg.isCloseMsg + } + return false +} + // work handles node work spinning // 1. collectMessage from upstream or just produce Msg from InputNode // 2. invoke node.Operate // 3. deliver the Operate result to downstream nodes func (nodeCtx *nodeCtx) work() { - defer nodeCtx.closeWg.Done() name := fmt.Sprintf("nodeCtxTtChecker-%s", nodeCtx.node.Name()) var checker *timerecord.GroupChecker if enableTtChecker { @@ -98,8 +103,19 @@ func (nodeCtx *nodeCtx) work() { nodeCtx.collectInputMessages() inputs = nodeCtx.inputMessages } - n := nodeCtx.node - res = n.Operate(inputs) + // the input message decides whether the operate method is executed + if isCloseMsg(inputs) { + res = inputs + } + if len(res) == 0 { + n := nodeCtx.node + res = n.Operate(inputs) + } + // the res decide whether the node should be closed. + if isCloseMsg(res) { + close(nodeCtx.closeCh) + nodeCtx.node.Close() + } if enableTtChecker { checker.Check(name) @@ -127,13 +143,7 @@ func (nodeCtx *nodeCtx) work() { // Close handles cleanup logic and notify worker to quit func (nodeCtx *nodeCtx) Close() { if nodeCtx.node.IsInputNode() { - nodeCtx.node.Close() // close input msgStream - close(nodeCtx.closeCh) - nodeCtx.closeWg.Wait() - } else { - close(nodeCtx.closeCh) - nodeCtx.closeWg.Wait() - nodeCtx.node.Close() // close output msgStream, and etc... + nodeCtx.node.Close() } } @@ -146,10 +156,7 @@ func (nodeCtx *nodeCtx) deliverMsg(wg *sync.WaitGroup, msg Msg, inputChanIdx int log.Warn(fmt.Sprintln(err)) } }() - select { - case <-nodeCtx.closeCh: - case nodeCtx.inputChannels[inputChanIdx] <- msg: - } + nodeCtx.inputChannels[inputChanIdx] <- msg } func (nodeCtx *nodeCtx) collectInputMessages() { @@ -161,17 +168,13 @@ func (nodeCtx *nodeCtx) collectInputMessages() { // and move them to inputMessages. for i := 0; i < inputsNum; i++ { channel := nodeCtx.inputChannels[i] - select { - case <-nodeCtx.closeCh: + msg, ok := <-channel + if !ok { + // TODO: add status + log.Warn("input channel closed") return - case msg, ok := <-channel: - if !ok { - // TODO: add status - log.Warn("input channel closed") - return - } - nodeCtx.inputMessages[i] = msg } + nodeCtx.inputMessages[i] = msg } // timeTick alignment check @@ -191,16 +194,12 @@ func (nodeCtx *nodeCtx) collectInputMessages() { for nodeCtx.inputMessages[i].TimeTick() != latestTime { log.Debug("Try to align timestamp", zap.Uint64("t1", latestTime), zap.Uint64("t2", nodeCtx.inputMessages[i].TimeTick())) channel := nodeCtx.inputChannels[i] - select { - case <-nodeCtx.closeCh: + msg, ok := <-channel + if !ok { + log.Warn("input channel closed") return - case msg, ok := <-channel: - if !ok { - log.Warn("input channel closed") - return - } - nodeCtx.inputMessages[i] = msg } + nodeCtx.inputMessages[i] = msg } } sign <- struct{}{} @@ -210,7 +209,6 @@ func (nodeCtx *nodeCtx) collectInputMessages() { case <-time.After(10 * time.Second): panic("Fatal, misaligned time tick, please restart pulsar") case <-sign: - case <-nodeCtx.closeCh: } } } diff --git a/internal/util/flowgraph/node_test.go b/internal/util/flowgraph/node_test.go index 1fcf2b81cc..e700e3b3e0 100644 --- a/internal/util/flowgraph/node_test.go +++ b/internal/util/flowgraph/node_test.go @@ -73,10 +73,7 @@ func TestNodeCtx_Start(t *testing.T) { produceStream.Produce(&msgPack) nodeName := "input_node" - inputNode := &InputNode{ - inStream: msgStream, - name: nodeName, - } + inputNode := NewInputNode(msgStream, nodeName, 100, 100) node := &nodeCtx{ node: inputNode,