github.com/metaworking/channeld@v0.7.3/pkg/fsm/fsm.go (about) 1 package fsm 2 3 import ( 4 "encoding/json" 5 "errors" 6 "strconv" 7 "strings" 8 "sync" 9 10 "go.uber.org/zap" 11 ) 12 13 type State struct { 14 Name string 15 MsgTypeWhitelist string // example: "1, 2-10, 30" 16 MsgTypeBlacklist string 17 18 allowedMsgTypes map[uint32]bool 19 transitions map[uint32]*State 20 } 21 22 type StateTransition struct { 23 FromState string 24 ToState string 25 MsgType uint32 26 } 27 28 type FiniteStateMachine struct { 29 InitState *string 30 States []State 31 Transitions []StateTransition 32 33 currentState *State 34 stateNameMap map[string]*State 35 lock *sync.RWMutex 36 } 37 38 func parseMsgTypes(s string, f func(msgType uint32)) { 39 if len(s) == 0 { 40 return 41 } 42 43 for _, seg := range strings.Split(s, ",") { 44 seg = strings.Trim(seg, " ") 45 fromTo := strings.Split(seg, "-") 46 if len(fromTo) == 1 { 47 msgType, err := strconv.ParseUint(fromTo[0], 10, 32) 48 if err != nil { 49 logger.Errorf("Can't convert '%s' to uint32\n", fromTo[0]) 50 } else { 51 f(uint32(msgType)) 52 } 53 } else { 54 fromType, err := strconv.ParseUint(fromTo[0], 10, 32) 55 if err != nil { 56 logger.Errorf("Can't convert '%s' to uint32\n", fromTo[0]) 57 } else { 58 toType, err := strconv.ParseUint(fromTo[1], 10, 32) 59 if err != nil { 60 logger.Errorf("Can't convert '%s' to uint32\n", fromTo[1]) 61 } else { 62 for i := fromType; i <= toType; i += 1 { 63 f(uint32(i)) 64 } 65 } 66 67 } 68 } 69 } 70 } 71 72 var logger *zap.SugaredLogger 73 74 func Load(bytes []byte) (*FiniteStateMachine, error) { 75 if logger == nil { 76 l, _ := zap.NewProduction() 77 defer l.Sync() 78 logger = l.Sugar() 79 } 80 81 var fsm FiniteStateMachine 82 err := json.Unmarshal(bytes, &fsm) 83 if err == nil { 84 if len(fsm.States) > 0 { 85 fsm.currentState = &fsm.States[0] 86 } 87 fsm.stateNameMap = make(map[string]*State, len(fsm.States)) 88 fsm.lock = &sync.RWMutex{} 89 90 for idx := range fsm.States { 91 state := &fsm.States[idx] 92 state.allowedMsgTypes = make(map[uint32]bool) 93 state.transitions = make(map[uint32]*State) 94 fsm.stateNameMap[state.Name] = state 95 parseMsgTypes(state.MsgTypeWhitelist, func(msgType uint32) { 96 state.allowedMsgTypes[msgType] = true 97 }) 98 parseMsgTypes(state.MsgTypeBlacklist, func(msgType uint32) { 99 state.allowedMsgTypes[msgType] = false 100 }) 101 } 102 103 for _, transition := range fsm.Transitions { 104 fromState, exists := fsm.stateNameMap[transition.FromState] 105 if !exists { 106 logger.Errorf("invalid FromState in StateTransition: %s -> %s (%d)\n", transition.FromState, transition.ToState, transition.MsgType) 107 continue 108 } 109 toState, exists := fsm.stateNameMap[transition.ToState] 110 if !exists { 111 logger.Errorf("invalid ToState in StateTransition: %s -> %s (%d)\n", transition.FromState, transition.ToState, transition.MsgType) 112 continue 113 } 114 fromState.transitions[transition.MsgType] = toState 115 } 116 117 if fsm.InitState != nil { 118 err = fsm.ChangeState(*fsm.InitState) 119 } 120 } 121 return &fsm, err 122 } 123 124 func (fsm *FiniteStateMachine) IsAllowed(msgType uint32) bool { 125 fsm.lock.RLock() 126 defer fsm.lock.RUnlock() 127 128 return fsm.currentState.allowedMsgTypes[msgType] 129 } 130 131 func (fsm *FiniteStateMachine) OnReceived(msgType uint32) { 132 fsm.lock.Lock() 133 defer fsm.lock.Unlock() 134 135 newState := fsm.currentState.transitions[msgType] 136 if newState != nil { 137 fsm.currentState = newState 138 } 139 } 140 141 func (fsm *FiniteStateMachine) CurrentState() *State { 142 fsm.lock.RLock() 143 defer fsm.lock.RUnlock() 144 145 return fsm.currentState 146 } 147 148 func (fsm *FiniteStateMachine) ChangeState(name string) error { 149 fsm.lock.Lock() 150 defer fsm.lock.Unlock() 151 152 state, exists := fsm.stateNameMap[name] 153 if exists { 154 fsm.currentState = state 155 return nil 156 } 157 return errors.New("Invalid state name: " + name) 158 } 159 160 func (fsm *FiniteStateMachine) MoveToNextState() bool { 161 fsm.lock.Lock() 162 defer fsm.lock.Unlock() 163 164 for i := 0; i < len(fsm.States)-1; i++ { 165 if fsm.currentState == &fsm.States[i] { 166 fsm.currentState = &fsm.States[i+1] 167 return true 168 } 169 } 170 return false 171 }