Improve the close method in the graph (#19100)

Signed-off-by: SimFG <bang.fu@zilliz.com>

Signed-off-by: SimFG <bang.fu@zilliz.com>
This commit is contained in:
SimFG 2022-09-09 10:00:36 +08:00 committed by GitHub
parent 3927ae9952
commit ceea04c274
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 86 additions and 86 deletions

View File

@ -89,13 +89,6 @@ func (fg *TimeTickedFlowGraph) Close() {
fg.stopOnce.Do(func() { fg.stopOnce.Do(func() {
for _, v := range fg.nodeCtx { for _, v := range fg.nodeCtx {
if v.node.IsInputNode() { if v.node.IsInputNode() {
// close inputNode first
v.Close()
}
}
for _, v := range fg.nodeCtx {
if !v.node.IsInputNode() {
// close other nodes
v.Close() v.Close()
} }
} }

View File

@ -30,6 +30,7 @@ type InputNode struct {
BaseNode BaseNode
inStream msgstream.MsgStream inStream msgstream.MsgStream
name string name string
closeMsgChan chan struct{}
} }
// IsInputNode returns whether Node is InputNode // IsInputNode returns whether Node is InputNode
@ -44,11 +45,16 @@ func (inNode *InputNode) Start() {
// Close implements node // Close implements node
func (inNode *InputNode) Close() { func (inNode *InputNode) Close() {
inNode.inStream.Close() select {
case <-inNode.closeMsgChan:
return
default:
close(inNode.closeMsgChan)
log.Debug("message stream closed", log.Debug("message stream closed",
zap.String("node name", inNode.name), zap.String("node name", inNode.name),
) )
} }
}
// Name returns node name // Name returns node name
func (inNode *InputNode) Name() string { func (inNode *InputNode) Name() string {
@ -62,7 +68,13 @@ func (inNode *InputNode) InStream() msgstream.MsgStream {
// Operate consume a message pack from msgstream and return // Operate consume a message pack from msgstream and return
func (inNode *InputNode) Operate(in []Msg) []Msg { func (inNode *InputNode) Operate(in []Msg) []Msg {
msgPack, ok := <-inNode.inStream.Chan() select {
case <-inNode.closeMsgChan:
inNode.inStream.Close()
return []Msg{&MsgStreamMsg{
isCloseMsg: true,
}}
case msgPack, ok := <-inNode.inStream.Chan():
if !ok { if !ok {
log.Warn("MsgStream closed", zap.Any("input node", inNode.Name())) log.Warn("MsgStream closed", zap.Any("input node", inNode.Name()))
return []Msg{} return []Msg{}
@ -94,6 +106,7 @@ func (inNode *InputNode) Operate(in []Msg) []Msg {
return []Msg{msgStreamMsg} return []Msg{msgStreamMsg}
} }
}
// NewInputNode composes an InputNode with provided MsgStream, name and parameters // NewInputNode composes an InputNode with provided MsgStream, name and parameters
func NewInputNode(inStream msgstream.MsgStream, nodeName string, maxQueueLength int32, maxParallelism int32) *InputNode { func NewInputNode(inStream msgstream.MsgStream, nodeName string, maxQueueLength int32, maxParallelism int32) *InputNode {
@ -105,5 +118,6 @@ func NewInputNode(inStream msgstream.MsgStream, nodeName string, maxQueueLength
BaseNode: baseNode, BaseNode: baseNode,
inStream: inStream, inStream: inStream,
name: nodeName, name: nodeName,
closeMsgChan: make(chan struct{}),
} }
} }

View File

@ -41,10 +41,7 @@ func TestInputNode(t *testing.T) {
produceStream.Produce(&msgPack) produceStream.Produce(&msgPack)
nodeName := "input_node" nodeName := "input_node"
inputNode := &InputNode{ inputNode := NewInputNode(msgStream, nodeName, 100, 100)
inStream: msgStream,
name: nodeName,
}
defer inputNode.Close() defer inputNode.Close()
isInputNode := inputNode.IsInputNode() isInputNode := inputNode.IsInputNode()

View File

@ -32,6 +32,7 @@ type MsgStreamMsg struct {
timestampMax Timestamp timestampMax Timestamp
startPositions []*MsgPosition startPositions []*MsgPosition
endPositions []*MsgPosition endPositions []*MsgPosition
isCloseMsg bool
} }
// GenerateMsgStreamMsg is used to create a new MsgStreamMsg object // GenerateMsgStreamMsg is used to create a new MsgStreamMsg object

View File

@ -59,23 +59,28 @@ type nodeCtx struct {
downstreamInputChanIdx map[string]int downstreamInputChanIdx map[string]int
closeCh chan struct{} // notify work to exit closeCh chan struct{} // notify work to exit
closeWg sync.WaitGroup // block Close until work exit
} }
// Start invoke Node `Start` method and start a worker goroutine // Start invoke Node `Start` method and start a worker goroutine
func (nodeCtx *nodeCtx) Start() { func (nodeCtx *nodeCtx) Start() {
nodeCtx.node.Start() nodeCtx.node.Start()
nodeCtx.closeWg.Add(1)
go nodeCtx.work() go nodeCtx.work()
} }
func isCloseMsg(msgs []Msg) bool {
if len(msgs) == 1 {
msg, ok := msgs[0].(*MsgStreamMsg)
return ok && msg.isCloseMsg
}
return false
}
// work handles node work spinning // work handles node work spinning
// 1. collectMessage from upstream or just produce Msg from InputNode // 1. collectMessage from upstream or just produce Msg from InputNode
// 2. invoke node.Operate // 2. invoke node.Operate
// 3. deliver the Operate result to downstream nodes // 3. deliver the Operate result to downstream nodes
func (nodeCtx *nodeCtx) work() { func (nodeCtx *nodeCtx) work() {
defer nodeCtx.closeWg.Done()
name := fmt.Sprintf("nodeCtxTtChecker-%s", nodeCtx.node.Name()) name := fmt.Sprintf("nodeCtxTtChecker-%s", nodeCtx.node.Name())
var checker *timerecord.GroupChecker var checker *timerecord.GroupChecker
if enableTtChecker { if enableTtChecker {
@ -98,8 +103,19 @@ func (nodeCtx *nodeCtx) work() {
nodeCtx.collectInputMessages() nodeCtx.collectInputMessages()
inputs = nodeCtx.inputMessages inputs = nodeCtx.inputMessages
} }
// the input message decides whether the operate method is executed
if isCloseMsg(inputs) {
res = inputs
}
if len(res) == 0 {
n := nodeCtx.node n := nodeCtx.node
res = n.Operate(inputs) res = n.Operate(inputs)
}
// the res decide whether the node should be closed.
if isCloseMsg(res) {
close(nodeCtx.closeCh)
nodeCtx.node.Close()
}
if enableTtChecker { if enableTtChecker {
checker.Check(name) checker.Check(name)
@ -127,13 +143,7 @@ func (nodeCtx *nodeCtx) work() {
// Close handles cleanup logic and notify worker to quit // Close handles cleanup logic and notify worker to quit
func (nodeCtx *nodeCtx) Close() { func (nodeCtx *nodeCtx) Close() {
if nodeCtx.node.IsInputNode() { if nodeCtx.node.IsInputNode() {
nodeCtx.node.Close() // close input msgStream nodeCtx.node.Close()
close(nodeCtx.closeCh)
nodeCtx.closeWg.Wait()
} else {
close(nodeCtx.closeCh)
nodeCtx.closeWg.Wait()
nodeCtx.node.Close() // close output msgStream, and etc...
} }
} }
@ -146,10 +156,7 @@ func (nodeCtx *nodeCtx) deliverMsg(wg *sync.WaitGroup, msg Msg, inputChanIdx int
log.Warn(fmt.Sprintln(err)) log.Warn(fmt.Sprintln(err))
} }
}() }()
select { nodeCtx.inputChannels[inputChanIdx] <- msg
case <-nodeCtx.closeCh:
case nodeCtx.inputChannels[inputChanIdx] <- msg:
}
} }
func (nodeCtx *nodeCtx) collectInputMessages() { func (nodeCtx *nodeCtx) collectInputMessages() {
@ -161,10 +168,7 @@ func (nodeCtx *nodeCtx) collectInputMessages() {
// and move them to inputMessages. // and move them to inputMessages.
for i := 0; i < inputsNum; i++ { for i := 0; i < inputsNum; i++ {
channel := nodeCtx.inputChannels[i] channel := nodeCtx.inputChannels[i]
select { msg, ok := <-channel
case <-nodeCtx.closeCh:
return
case msg, ok := <-channel:
if !ok { if !ok {
// TODO: add status // TODO: add status
log.Warn("input channel closed") log.Warn("input channel closed")
@ -172,7 +176,6 @@ func (nodeCtx *nodeCtx) collectInputMessages() {
} }
nodeCtx.inputMessages[i] = msg nodeCtx.inputMessages[i] = msg
} }
}
// timeTick alignment check // timeTick alignment check
if len(nodeCtx.inputMessages) > 1 { if len(nodeCtx.inputMessages) > 1 {
@ -191,10 +194,7 @@ func (nodeCtx *nodeCtx) collectInputMessages() {
for nodeCtx.inputMessages[i].TimeTick() != latestTime { for nodeCtx.inputMessages[i].TimeTick() != latestTime {
log.Debug("Try to align timestamp", zap.Uint64("t1", latestTime), zap.Uint64("t2", nodeCtx.inputMessages[i].TimeTick())) log.Debug("Try to align timestamp", zap.Uint64("t1", latestTime), zap.Uint64("t2", nodeCtx.inputMessages[i].TimeTick()))
channel := nodeCtx.inputChannels[i] channel := nodeCtx.inputChannels[i]
select { msg, ok := <-channel
case <-nodeCtx.closeCh:
return
case msg, ok := <-channel:
if !ok { if !ok {
log.Warn("input channel closed") log.Warn("input channel closed")
return return
@ -202,7 +202,6 @@ func (nodeCtx *nodeCtx) collectInputMessages() {
nodeCtx.inputMessages[i] = msg nodeCtx.inputMessages[i] = msg
} }
} }
}
sign <- struct{}{} sign <- struct{}{}
}() }()
@ -210,7 +209,6 @@ func (nodeCtx *nodeCtx) collectInputMessages() {
case <-time.After(10 * time.Second): case <-time.After(10 * time.Second):
panic("Fatal, misaligned time tick, please restart pulsar") panic("Fatal, misaligned time tick, please restart pulsar")
case <-sign: case <-sign:
case <-nodeCtx.closeCh:
} }
} }
} }

View File

@ -73,10 +73,7 @@ func TestNodeCtx_Start(t *testing.T) {
produceStream.Produce(&msgPack) produceStream.Produce(&msgPack)
nodeName := "input_node" nodeName := "input_node"
inputNode := &InputNode{ inputNode := NewInputNode(msgStream, nodeName, 100, 100)
inStream: msgStream,
name: nodeName,
}
node := &nodeCtx{ node := &nodeCtx{
node: inputNode, node: inputNode,