From 8020dc2256b401cc470a891445fe0978ce710d6e Mon Sep 17 00:00:00 2001
From: bigsheeper
Date: Mon, 2 Nov 2020 16:44:54 +0800
Subject: [PATCH] Add flow graph

Signed-off-by: bigsheeper
---
 internal/util/flowgraph/flow_graph.go      | 124 +++++++++
 internal/util/flowgraph/flow_graph_test.go | 278 +++++++++++++++++++++
 internal/util/flowgraph/message.go         |  10 +
 internal/util/flowgraph/node.go            | 164 ++++++++++++
 4 files changed, 576 insertions(+)
 create mode 100644 internal/util/flowgraph/flow_graph.go
 create mode 100644 internal/util/flowgraph/flow_graph_test.go
 create mode 100644 internal/util/flowgraph/message.go
 create mode 100644 internal/util/flowgraph/node.go

diff --git a/internal/util/flowgraph/flow_graph.go b/internal/util/flowgraph/flow_graph.go
new file mode 100644
index 0000000000..b775d64682
--- /dev/null
+++ b/internal/util/flowgraph/flow_graph.go
@@ -0,0 +1,124 @@
+// Package flowgraph implements a small graph of processing nodes that pass
+// messages to each other over buffered channels.
+package flowgraph
+
+import (
+	"context"
+	"sync"
+
+	"github.com/zilliztech/milvus-distributed/internal/errors"
+	"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
+)
+
+// Timestamp is an alias of typeutil.Timestamp.
+type Timestamp = typeutil.Timestamp
+
+// flowGraphStates is the state kept per flow graph. Node.SetPipelineStates
+// accepts a pointer to it.
+type flowGraphStates struct {
+	startTick         Timestamp
+	numActiveTasks    map[string]int64
+	numCompletedTasks map[string]int64
+}
+
+// TimeTickedFlowGraph is a set of named nodes connected by channels. Nodes are
+// registered with AddNode, wired together with SetEdges and run with Start.
+type TimeTickedFlowGraph struct {
+	ctx     context.Context
+	states  *flowGraphStates
+	nodeCtx map[string]*nodeCtx
+}
+
+// AddNode registers node under the name returned by its Name method. Adding a
+// second node with the same name replaces the first one.
+func (fg *TimeTickedFlowGraph) AddNode(node *Node) {
+	nodeName := (*node).Name()
+	fg.nodeCtx[nodeName] = &nodeCtx{
+		node:                   node,
+		inputChannels:          make([]chan *Msg, 0),
+		downstreamInputChanIdx: make(map[string]int),
+	}
+}
+
+// SetEdges wires the node named nodeName into the graph. in lists its upstream
+// nodes in input-channel order and out lists its downstream nodes. When any of
+// the names has not been added, SetEdges returns an error and leaves the graph
+// untouched.
+//
+// Every call appends one new input channel to each node in out, so SetEdges is
+// meant to be called once per node.
+func (fg *TimeTickedFlowGraph) SetEdges(nodeName string, in []string, out []string) error {
+	currentNode, ok := fg.nodeCtx[nodeName]
+	if !ok {
+		errMsg := "Cannot find node:" + nodeName
+		return errors.New(errMsg)
+	}
+
+	// Resolve every name before changing anything, so that a bad name cannot
+	// leave the graph half wired.
+	inNodes := make([]*nodeCtx, 0, len(in))
+	for _, inNodeName := range in {
+		inNode, ok := fg.nodeCtx[inNodeName]
+		if !ok {
+			errMsg := "Cannot find in node:" + inNodeName
+			return errors.New(errMsg)
+		}
+		inNodes = append(inNodes, inNode)
+	}
+	outNodes := make([]*nodeCtx, 0, len(out))
+	for _, outNodeName := range out {
+		outNode, ok := fg.nodeCtx[outNodeName]
+		if !ok {
+			errMsg := "Cannot find out node:" + outNodeName
+			return errors.New(errMsg)
+		}
+		outNodes = append(outNodes, outNode)
+	}
+
+	// Tell each upstream node which input channel of this node it feeds.
+	for i, inNode := range inNodes {
+		inNode.downstreamInputChanIdx[nodeName] = i
+	}
+
+	// Give each downstream node one more input channel, sized by that node.
+	for _, outNode := range outNodes {
+		queueLength := (*outNode.node).MaxQueueLength()
+		outNode.inputChannels = append(outNode.inputChannels, make(chan *Msg, queueLength))
+	}
+	currentNode.downstream = outNodes
+
+	return nil
+}
+
+// Start runs every node in its own goroutine and blocks until all of them have
+// returned.
+func (fg *TimeTickedFlowGraph) Start() {
+	var wg sync.WaitGroup
+	for _, v := range fg.nodeCtx {
+		wg.Add(1)
+		go v.Start(fg.ctx, &wg)
+	}
+	wg.Wait()
+}
+
+// Close closes the input channels of every node. It always returns nil.
+func (fg *TimeTickedFlowGraph) Close() error {
+	for _, v := range fg.nodeCtx {
+		v.Close()
+	}
+	return nil
+}
+
+// NewTimeTickedFlowGraph returns an empty flow graph whose nodes run until ctx
+// is done.
+func NewTimeTickedFlowGraph(ctx context.Context) *TimeTickedFlowGraph {
+	return &TimeTickedFlowGraph{
+		ctx: ctx,
+		states: &flowGraphStates{
+			startTick:         0,
+			numActiveTasks:    make(map[string]int64),
+			numCompletedTasks: make(map[string]int64),
+		},
+		nodeCtx: make(map[string]*nodeCtx),
+	}
+}
diff --git a/internal/util/flowgraph/flow_graph_test.go b/internal/util/flowgraph/flow_graph_test.go
new file mode 100644
index 0000000000..eebc7ffe18
--- /dev/null
+++ b/internal/util/flowgraph/flow_graph_test.go
@@ -0,0 +1,278 @@
+package flowgraph
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"math"
+	"math/rand"
+	"sync"
+	"testing"
+	"time"
+)
+
+const ctxTimeInMillisecond = 3000
+
+// The test graph is a diamond: nodeA duplicates its input for nodeB and nodeC,
+// nodeB squares it, nodeC takes its square root and nodeD adds both results.
+
+type nodeA struct {
+	baseNode
+	a float64
+}
+
+type nodeB struct {
+	baseNode
+	b float64
+}
+
+type nodeC struct {
+	baseNode
+	c float64
+}
+
+type nodeD struct {
+	baseNode
+	d       float64
+	resChan chan float64
+}
+
+type intMsg struct {
+	num float64
+	t   Timestamp
+}
+
+func (m *intMsg) TimeTick() Timestamp {
+	return m.t
+}
+
+func (m *intMsg) DownStreamNodeIdx() int32 {
+	return 1
+}
+
+// intMsg2Msg wraps every *intMsg in a Msg interface value.
+func intMsg2Msg(in []*intMsg) []*Msg {
+	out := make([]*Msg, 0, len(in))
+	for _, msg := range in {
+		var m Msg = msg
+		out = append(out, &m)
+	}
+	return out
+}
+
+// msg2IntMsg unwraps messages that hold *intMsg values; it panics on any
+// other message type.
+func msg2IntMsg(in []*Msg) []*intMsg {
+	out := make([]*intMsg, 0, len(in))
+	for _, msg := range in {
+		out = append(out, (*msg).(*intMsg))
+	}
+	return out
+}
+
+func (a *nodeA) Name() string {
+	return "NodeA"
+}
+
+// Operate returns its input twice, one copy per downstream node.
+func (a *nodeA) Operate(in []*Msg) []*Msg {
+	return append(in, in...)
+}
+
+func (b *nodeB) Name() string {
+	return "NodeB"
+}
+
+// Operate squares every input number.
+func (b *nodeB) Operate(in []*Msg) []*Msg {
+	messages := make([]*intMsg, 0, len(in))
+	for _, msg := range msg2IntMsg(in) {
+		messages = append(messages, &intMsg{
+			num: math.Pow(msg.num, 2),
+		})
+	}
+	return intMsg2Msg(messages)
+}
+
+func (c *nodeC) Name() string {
+	return "NodeC"
+}
+
+// Operate takes the square root of every input number.
+func (c *nodeC) Operate(in []*Msg) []*Msg {
+	messages := make([]*intMsg, 0, len(in))
+	for _, msg := range msg2IntMsg(in) {
+		messages = append(messages, &intMsg{
+			num: math.Sqrt(msg.num),
+		})
+	}
+	return intMsg2Msg(messages)
+}
+
+func (d *nodeD) Name() string {
+	return "NodeD"
+}
+
+// Operate adds the first half of the inputs to the second half element-wise
+// and publishes the first sum on resChan.
+func (d *nodeD) Operate(in []*Msg) []*Msg {
+	outLength := len(in) / 2
+	if outLength == 0 {
+		// Fewer than two inputs: there is no pair to add and no result.
+		return nil
+	}
+	inMessages := msg2IntMsg(in)
+	messages := make([]*intMsg, 0, outLength)
+	for i := 0; i < outLength; i++ {
+		messages = append(messages, &intMsg{
+			num: inMessages[i].num + inMessages[i+outLength].num,
+		})
+	}
+	d.d = messages[0].num
+	d.resChan <- d.d
+	fmt.Println("flow graph result:", d.d)
+	return intMsg2Msg(messages)
+}
+
+// sendMsgFromCmd feeds a random number into nodeA every 500ms until ctx is
+// done, and fails the test when the result published by nodeD is not
+// num^2 + sqrt(num).
+func sendMsgFromCmd(ctx context.Context, t *testing.T, fg *TimeTickedFlowGraph) {
+	a := nodeA{}
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		default:
+			time.Sleep(time.Millisecond * time.Duration(500))
+			num := float64(rand.Int() % 100)
+			var msg Msg = &intMsg{num: num}
+			fg.nodeCtx[a.Name()].inputChannels[0] <- &msg
+			fmt.Println("send number", num, "to node", a.Name())
+			res, ok := receiveResult(ctx, fg)
+			if !ok {
+				return
+			}
+			// assert result
+			want := math.Pow(num, 2) + math.Sqrt(num)
+			if math.Abs(res-want) > 1e-9 {
+				t.Fatalf("wrong answer for %v: got %v, want %v", num, res, want)
+			}
+		}
+	}
+}
+
+// receiveResultFromNodeD stores the next result published by nodeD in res and
+// then calls wg.Done.
+func receiveResultFromNodeD(res *float64, fg *TimeTickedFlowGraph, wg *sync.WaitGroup) {
+	defer wg.Done()
+	d := nodeD{}
+	node := fg.nodeCtx[d.Name()]
+	nd, ok := (*node.node).(*nodeD)
+	if !ok {
+		log.Fatal("not nodeD type")
+	}
+	*res = <-nd.resChan
+}
+
+// receiveResult returns the next result published by nodeD, or false when ctx
+// is done first.
+func receiveResult(ctx context.Context, fg *TimeTickedFlowGraph) (float64, bool) {
+	d := nodeD{}
+	node := fg.nodeCtx[d.Name()]
+	nd, ok := (*node.node).(*nodeD)
+	if !ok {
+		log.Fatal("not nodeD type")
+	}
+	select {
+	case <-ctx.Done():
+		return 0, false
+	case res := <-nd.resChan:
+		return res, true
+	}
+}
+
+func TestTimeTickedFlowGraph_Start(t *testing.T) {
+	// Keep the cancel function: dropping it leaks the context's timer until
+	// the deadline passes.
+	ctx, cancel := context.WithTimeout(context.Background(), ctxTimeInMillisecond*time.Millisecond)
+	defer cancel()
+	fg := NewTimeTickedFlowGraph(ctx)
+
+	var a Node = &nodeA{
+		baseNode: baseNode{
+			maxQueueLength: maxQueueLength,
+		},
+	}
+	var b Node = &nodeB{
+		baseNode: baseNode{
+			maxQueueLength: maxQueueLength,
+		},
+	}
+	var c Node = &nodeC{
+		baseNode: baseNode{
+			maxQueueLength: maxQueueLength,
+		},
+	}
+	var d Node = &nodeD{
+		baseNode: baseNode{
+			maxQueueLength: maxQueueLength,
+		},
+		// Buffered so that nodeD can publish its last result even when the
+		// test has already stopped receiving.
+		resChan: make(chan float64, 1),
+	}
+
+	fg.AddNode(&a)
+	fg.AddNode(&b)
+	fg.AddNode(&c)
+	fg.AddNode(&d)
+
+	err := fg.SetEdges(a.Name(),
+		[]string{},
+		[]string{b.Name(), c.Name()},
+	)
+	if err != nil {
+		t.Fatal("set edges failed:", err)
+	}
+
+	err = fg.SetEdges(b.Name(),
+		[]string{a.Name()},
+		[]string{d.Name()},
+	)
+	if err != nil {
+		t.Fatal("set edges failed:", err)
+	}
+
+	err = fg.SetEdges(c.Name(),
+		[]string{a.Name()},
+		[]string{d.Name()},
+	)
+	if err != nil {
+		t.Fatal("set edges failed:", err)
+	}
+
+	err = fg.SetEdges(d.Name(),
+		[]string{b.Name(), c.Name()},
+		[]string{},
+	)
+	if err != nil {
+		t.Fatal("set edges failed:", err)
+	}
+
+	// init node A
+	nodeCtxA := fg.nodeCtx[a.Name()]
+	nodeCtxA.inputChannels = []chan *Msg{make(chan *Msg, 10)}
+
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		fg.Start()
+	}()
+
+	sendMsgFromCmd(ctx, t, fg)
+
+	// Every node has to return once the context is done.
+	cancel()
+	<-done
+}
diff --git a/internal/util/flowgraph/message.go b/internal/util/flowgraph/message.go
new file mode 100644
index 0000000000..6abf8c2c24
--- /dev/null
+++ b/internal/util/flowgraph/message.go
@@ -0,0 +1,10 @@
+package flowgraph
+
+// Msg is a message passed between the nodes of a flow graph.
+type Msg interface {
+	// TimeTick returns the Timestamp carried by the message.
+	TimeTick() Timestamp
+	// DownStreamNodeIdx is presumably the index of the downstream node the
+	// message is meant for; nothing in this package reads it yet.
+	DownStreamNodeIdx() int32
+}
diff --git a/internal/util/flowgraph/node.go b/internal/util/flowgraph/node.go
new file mode 100644
index 0000000000..5abcd5f83c
--- /dev/null
+++ b/internal/util/flowgraph/node.go
@@ -0,0 +1,164 @@
+package flowgraph
+
+import (
+	"context"
+	"fmt"
+	"sync"
+)
+
+// maxQueueLength is a default capacity for a node's input channels.
+const maxQueueLength = 1024
+
+// Node is a processing step of a flow graph.
+type Node interface {
+	// Name returns the key under which the graph stores the node.
+	Name() string
+	// MaxQueueLength returns the capacity of each input channel that SetEdges
+	// creates for the node.
+	MaxQueueLength() int32
+	// MaxParallelism is not consulted by the graph yet.
+	MaxParallelism() int32
+	// SetPipelineStates hands the node a flowGraphStates value.
+	SetPipelineStates(states *flowGraphStates)
+	// Operate receives one message per input channel, in input-channel order,
+	// and returns the messages to forward: element i goes to downstream node i.
+	Operate(in []*Msg) []*Msg
+}
+
+// baseNode carries the settings shared by node implementations; embedding it
+// provides every Node method except Name and Operate.
+type baseNode struct {
+	maxQueueLength int32
+	maxParallelism int32
+	graphStates    *flowGraphStates
+}
+
+// nodeCtx is the graph's bookkeeping for one node.
+type nodeCtx struct {
+	// node is the wrapped Node.
+	node *Node
+	// inputChannels holds one channel per upstream edge.
+	inputChannels []chan *Msg
+	// inputMessages holds the messages taken from inputChannels for the
+	// current round, indexed like inputChannels.
+	inputMessages [][]*Msg
+	// downstream lists the nodes that receive this node's output.
+	downstream []*nodeCtx
+	// downstreamInputChanIdx maps the name of a downstream node to the index
+	// of the input channel of that node which this node feeds.
+	downstreamInputChanIdx map[string]int
+}
+
+// Start runs the node until ctx is done or one of its input channels is
+// closed, and calls wg.Done when it returns.
+//
+// Each round takes one message from every input channel, passes them to
+// Operate and delivers element i of the result to downstream node i.
+func (nc *nodeCtx) Start(ctx context.Context, wg *sync.WaitGroup) {
+	defer wg.Done()
+	for {
+		if !nc.getMessagesFromChannel(ctx) {
+			return
+		}
+		// Flatten inputMessages into the single slice that Operate expects.
+		inputs := make([]*Msg, 0, len(nc.inputMessages))
+		for _, messages := range nc.inputMessages {
+			inputs = append(inputs, messages...)
+		}
+		n := *nc.node
+		res := n.Operate(inputs)
+
+		// Deliver concurrently so that one full downstream queue does not
+		// delay the others, then wait for every delivery before the next
+		// round.
+		var sendWg sync.WaitGroup
+		for i, downstream := range nc.downstream {
+			if i >= len(res) {
+				// Operate returned fewer messages than there are downstream
+				// nodes; there is nothing to deliver to the remaining ones.
+				break
+			}
+			inputChanIdx, ok := nc.downstreamInputChanIdx[(*downstream.node).Name()]
+			if !ok {
+				// The downstream node never listed this node as an input.
+				continue
+			}
+			sendWg.Add(1)
+			go downstream.ReceiveMsg(&sendWg, res[i], inputChanIdx)
+		}
+		sendWg.Wait()
+	}
+}
+
+// Close closes every input channel of the node, which makes a running Start
+// return. Call it only once, and only after everything that sends to the node
+// has stopped: sending on a closed channel panics.
+func (nc *nodeCtx) Close() {
+	for _, channel := range nc.inputChannels {
+		close(channel)
+	}
+}
+
+// ReceiveMsg puts msg into input channel inputChanIdx of the node, blocking
+// while that channel is full, and calls wg.Done when it returns. A message for
+// an input channel that does not exist is reported and dropped.
+func (nc *nodeCtx) ReceiveMsg(wg *sync.WaitGroup, msg *Msg, inputChanIdx int) {
+	defer wg.Done()
+	if inputChanIdx < 0 || inputChanIdx >= len(nc.inputChannels) {
+		fmt.Println("node:", (*nc.node).Name(), "has no input channel ", inputChanIdx)
+		return
+	}
+	nc.inputChannels[inputChanIdx] <- msg
+	fmt.Println("node:", (*nc.node).Name(), "receive to input channel ", inputChanIdx)
+}
+
+// getMessagesFromChannel takes one message from every input channel, in
+// order, and stores them in inputMessages. It blocks until each channel has
+// delivered a message and returns false, meaning the node should stop, when
+// ctx is done or a channel has been closed.
+func (nc *nodeCtx) getMessagesFromChannel(ctx context.Context) bool {
+	// A node without input channels never blocks below, so check ctx here to
+	// make sure that it still stops.
+	if ctx.Err() != nil {
+		return false
+	}
+
+	nc.inputMessages = make([][]*Msg, len(nc.inputChannels))
+	for i, channel := range nc.inputChannels {
+		select {
+		case <-ctx.Done():
+			return false
+		case msg, ok := <-channel:
+			if !ok {
+				return false
+			}
+			nc.inputMessages[i] = []*Msg{msg}
+		}
+	}
+	return true
+}
+
+// MaxQueueLength returns the configured input-channel capacity.
+func (node *baseNode) MaxQueueLength() int32 {
+	return node.maxQueueLength
+}
+
+// MaxParallelism returns the configured parallelism.
+func (node *baseNode) MaxParallelism() int32 {
+	return node.maxParallelism
+}
+
+// SetMaxQueueLength sets the value returned by MaxQueueLength.
+func (node *baseNode) SetMaxQueueLength(n int32) {
+	node.maxQueueLength = n
+}
+
+// SetMaxParallelism sets the value returned by MaxParallelism.
+func (node *baseNode) SetMaxParallelism(n int32) {
+	node.maxParallelism = n
+}
+
+// SetPipelineStates stores states as the node's view of the graph state.
+func (node *baseNode) SetPipelineStates(states *flowGraphStates) {
+	node.graphStates = states
+}