github.com/15mga/kiwi@v0.0.2-0.20240324021231-b95d5c3ac751/graph/fsm_test.go (about)

     1  package graph
     2  
     3  import (
     4  	"fmt"
     5  	"github.com/15mga/kiwi"
     6  	"strings"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/15mga/kiwi/util"
    11  	"github.com/15mga/kiwi/worker"
    12  	"github.com/stretchr/testify/assert"
    13  )
    14  
    15  const (
    16  	PTimout = "timeout"
    17  )
    18  
    19  func NewFsmPlugin() *fsmPlugin {
    20  	return &fsmPlugin{
    21  		pathToTimeout: make(map[string]time.Duration),
    22  		manualDisable: make(map[string]struct{}),
    23  		autoEnable:    make(map[string]INode),
    24  		worker:        worker.NewFnWorker(),
    25  	}
    26  }
    27  
    28  type fsmPlugin struct {
    29  	worker        *worker.FnWorker
    30  	pathToTimeout map[string]time.Duration //激活节点超时关闭节点
    31  	manualDisable map[string]struct{}      //手动关闭,节点输出后不自动关闭
    32  	autoEnable    map[string]INode
    33  }
    34  
    35  func (s *fsmPlugin) SetNodeTimeout(dur time.Duration, path ...string) {
    36  	s.pathToTimeout[strings.Join(path, ".")] = dur
    37  }
    38  
    39  func (s *fsmPlugin) SetManualDisable(path ...string) {
    40  	s.manualDisable[strings.Join(path, ".")] = struct{}{}
    41  }
    42  
    43  func (s *fsmPlugin) SetAutoEnable(path ...string) {
    44  	s.autoEnable[strings.Join(path, ".")] = nil
    45  }
    46  
    47  func (s *fsmPlugin) OnStart(g IGraph) {
    48  	s.worker.Start()
    49  }
    50  
    51  func (s *fsmPlugin) OnAddIn(in IIn) {
    52  	in.AddFilter(func(msg IMsg) *util.Err {
    53  		inNode := in.Node()
    54  		_, ok := s.manualDisable[inNode.Path()]
    55  		if !ok {
    56  			err := inNode.SetEnable(true)
    57  			if err != nil {
    58  				kiwi.Error(err)
    59  			}
    60  		}
    61  		return nil
    62  	})
    63  }
    64  
    65  func (s *fsmPlugin) OnAddOut(out IOut) {
    66  	out.AddFilter(func(msg IMsg) *util.Err {
    67  		outNode := out.Node()
    68  		_, ok := s.manualDisable[outNode.Path()]
    69  		if !ok {
    70  			err := outNode.SetEnable(false)
    71  			if err != nil {
    72  				kiwi.Error(err)
    73  			}
    74  		}
    75  		return nil
    76  	})
    77  }
    78  
    79  func (s *fsmPlugin) OnAddNode(g IGraph, nd INode) {
    80  	nd.SetData(util.M{})
    81  
    82  	dur, ok := s.pathToTimeout[nd.Path()]
    83  	if ok {
    84  		_ = nd.AddOut(TpNone, PTimout)
    85  		nd.AddAfterEnable(func(enable bool) {
    86  			if enable {
    87  				nd.Data().Set(PTimout, time.AfterFunc(dur, func() {
    88  					s.worker.Push(func(params []any) {
    89  						nd := params[0].(INode)
    90  						err := nd.Out(PTimout, nil)
    91  						if err != nil {
    92  							kiwi.Error(err)
    93  						}
    94  						nd.Data().Del(PTimout)
    95  					}, nd)
    96  				}))
    97  			} else {
    98  				timer, exit := util.MPop[time.Timer](nd.Data(), PTimout)
    99  				if exit {
   100  					timer.Stop()
   101  				}
   102  			}
   103  		})
   104  	}
   105  
   106  	_, ok = s.autoEnable[nd.Path()]
   107  	_ = nd.SetEnable(ok)
   108  }
   109  
   110  func (s *fsmPlugin) OnAddSubGraph(g IGraph, sg ISubGraph) {
   111  
   112  }
   113  
   114  func (s *fsmPlugin) OnAddLink(g IGraph, lnk ILink) {
   115  
   116  }
   117  
   118  func newFsmGraph() (IGraph, *fsmPlugin) {
   119  	fsm := NewFsmPlugin()
   120  	fsm.SetNodeTimeout(time.Second, "init")
   121  	fsm.SetAutoEnable("init")
   122  	g := NewGraph("test", Plugin(fsm))
   123  	return g, fsm
   124  }
   125  
   126  func TestFsmPlugin_SetNodeTimeout(t *testing.T) {
   127  	g, _ := newFsmGraph()
   128  	node, err := g.AddNode("init")
   129  	assert.Nil(t, err)
   130  	assert.NotNil(t, node)
   131  
   132  	op, err := node.GetOut(PTimout)
   133  	assert.Nil(t, err)
   134  	assert.NotNil(t, op)
   135  
   136  	ch := make(chan struct{})
   137  	timer := time.NewTimer(time.Millisecond * 1500)
   138  	startTime := time.Now()
   139  	op.AddFilter(func(msg IMsg) *util.Err {
   140  		bytes, err := msg.ToJson()
   141  		assert.Nil(t, err)
   142  		fmt.Println(util.BytesToStr(bytes), time.Since(startTime).Milliseconds())
   143  		ch <- struct{}{}
   144  		return nil
   145  	})
   146  	_ = g.Start()
   147  	select {
   148  	case <-timer.C:
   149  		t.Error("timeout")
   150  	case <-ch:
   151  		timer.Stop()
   152  	}
   153  }