github.com/metaworking/channeld@v0.7.3/pkg/fsm/fsm.go (about)

     1  package fsm
     2  
     3  import (
     4  	"encoding/json"
     5  	"errors"
     6  	"strconv"
     7  	"strings"
     8  	"sync"
     9  
    10  	"go.uber.org/zap"
    11  )
    12  
    13  type State struct {
    14  	Name             string
    15  	MsgTypeWhitelist string // example: "1, 2-10, 30"
    16  	MsgTypeBlacklist string
    17  
    18  	allowedMsgTypes map[uint32]bool
    19  	transitions     map[uint32]*State
    20  }
    21  
    22  type StateTransition struct {
    23  	FromState string
    24  	ToState   string
    25  	MsgType   uint32
    26  }
    27  
    28  type FiniteStateMachine struct {
    29  	InitState   *string
    30  	States      []State
    31  	Transitions []StateTransition
    32  
    33  	currentState *State
    34  	stateNameMap map[string]*State
    35  	lock         *sync.RWMutex
    36  }
    37  
    38  func parseMsgTypes(s string, f func(msgType uint32)) {
    39  	if len(s) == 0 {
    40  		return
    41  	}
    42  
    43  	for _, seg := range strings.Split(s, ",") {
    44  		seg = strings.Trim(seg, " ")
    45  		fromTo := strings.Split(seg, "-")
    46  		if len(fromTo) == 1 {
    47  			msgType, err := strconv.ParseUint(fromTo[0], 10, 32)
    48  			if err != nil {
    49  				logger.Errorf("Can't convert '%s' to uint32\n", fromTo[0])
    50  			} else {
    51  				f(uint32(msgType))
    52  			}
    53  		} else {
    54  			fromType, err := strconv.ParseUint(fromTo[0], 10, 32)
    55  			if err != nil {
    56  				logger.Errorf("Can't convert '%s' to uint32\n", fromTo[0])
    57  			} else {
    58  				toType, err := strconv.ParseUint(fromTo[1], 10, 32)
    59  				if err != nil {
    60  					logger.Errorf("Can't convert '%s' to uint32\n", fromTo[1])
    61  				} else {
    62  					for i := fromType; i <= toType; i += 1 {
    63  						f(uint32(i))
    64  					}
    65  				}
    66  
    67  			}
    68  		}
    69  	}
    70  }
    71  
    72  var logger *zap.SugaredLogger
    73  
    74  func Load(bytes []byte) (*FiniteStateMachine, error) {
    75  	if logger == nil {
    76  		l, _ := zap.NewProduction()
    77  		defer l.Sync()
    78  		logger = l.Sugar()
    79  	}
    80  
    81  	var fsm FiniteStateMachine
    82  	err := json.Unmarshal(bytes, &fsm)
    83  	if err == nil {
    84  		if len(fsm.States) > 0 {
    85  			fsm.currentState = &fsm.States[0]
    86  		}
    87  		fsm.stateNameMap = make(map[string]*State, len(fsm.States))
    88  		fsm.lock = &sync.RWMutex{}
    89  
    90  		for idx := range fsm.States {
    91  			state := &fsm.States[idx]
    92  			state.allowedMsgTypes = make(map[uint32]bool)
    93  			state.transitions = make(map[uint32]*State)
    94  			fsm.stateNameMap[state.Name] = state
    95  			parseMsgTypes(state.MsgTypeWhitelist, func(msgType uint32) {
    96  				state.allowedMsgTypes[msgType] = true
    97  			})
    98  			parseMsgTypes(state.MsgTypeBlacklist, func(msgType uint32) {
    99  				state.allowedMsgTypes[msgType] = false
   100  			})
   101  		}
   102  
   103  		for _, transition := range fsm.Transitions {
   104  			fromState, exists := fsm.stateNameMap[transition.FromState]
   105  			if !exists {
   106  				logger.Errorf("invalid FromState in StateTransition: %s -> %s (%d)\n", transition.FromState, transition.ToState, transition.MsgType)
   107  				continue
   108  			}
   109  			toState, exists := fsm.stateNameMap[transition.ToState]
   110  			if !exists {
   111  				logger.Errorf("invalid ToState in StateTransition: %s -> %s (%d)\n", transition.FromState, transition.ToState, transition.MsgType)
   112  				continue
   113  			}
   114  			fromState.transitions[transition.MsgType] = toState
   115  		}
   116  
   117  		if fsm.InitState != nil {
   118  			err = fsm.ChangeState(*fsm.InitState)
   119  		}
   120  	}
   121  	return &fsm, err
   122  }
   123  
   124  func (fsm *FiniteStateMachine) IsAllowed(msgType uint32) bool {
   125  	fsm.lock.RLock()
   126  	defer fsm.lock.RUnlock()
   127  
   128  	return fsm.currentState.allowedMsgTypes[msgType]
   129  }
   130  
   131  func (fsm *FiniteStateMachine) OnReceived(msgType uint32) {
   132  	fsm.lock.Lock()
   133  	defer fsm.lock.Unlock()
   134  
   135  	newState := fsm.currentState.transitions[msgType]
   136  	if newState != nil {
   137  		fsm.currentState = newState
   138  	}
   139  }
   140  
   141  func (fsm *FiniteStateMachine) CurrentState() *State {
   142  	fsm.lock.RLock()
   143  	defer fsm.lock.RUnlock()
   144  
   145  	return fsm.currentState
   146  }
   147  
   148  func (fsm *FiniteStateMachine) ChangeState(name string) error {
   149  	fsm.lock.Lock()
   150  	defer fsm.lock.Unlock()
   151  
   152  	state, exists := fsm.stateNameMap[name]
   153  	if exists {
   154  		fsm.currentState = state
   155  		return nil
   156  	}
   157  	return errors.New("Invalid state name: " + name)
   158  }
   159  
   160  func (fsm *FiniteStateMachine) MoveToNextState() bool {
   161  	fsm.lock.Lock()
   162  	defer fsm.lock.Unlock()
   163  
   164  	for i := 0; i < len(fsm.States)-1; i++ {
   165  		if fsm.currentState == &fsm.States[i] {
   166  			fsm.currentState = &fsm.States[i+1]
   167  			return true
   168  		}
   169  	}
   170  	return false
   171  }