github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/util/fsm/fsm_test.go (about)

     1  // Copyright 2017 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package fsm
    12  
    13  import (
    14  	"context"
    15  	"testing"
    16  
    17  	"github.com/stretchr/testify/require"
    18  )
    19  
    20  type state1 struct{}
    21  type state2 struct{}
    22  type state3 struct {
    23  	Field Bool
    24  }
    25  type state4 struct {
    26  	Field1 Bool
    27  	Field2 Bool
    28  }
    29  
    30  func (state1) State() {}
    31  func (state2) State() {}
    32  func (state3) State() {}
    33  func (state4) State() {}
    34  
    35  type event1 struct{}
    36  type event2 struct{}
    37  type event3 struct {
    38  	Field Bool
    39  }
    40  type event4 struct {
    41  	Field1 Bool
    42  	Field2 Bool
    43  }
    44  
    45  func (event1) Event() {}
    46  func (event2) Event() {}
    47  func (event3) Event() {}
    48  func (event4) Event() {}
    49  
    50  var noAction func(Args) error
    51  
    52  func noErr(f func(Args)) func(Args) error {
    53  	return func(a Args) error { f(a); return nil }
    54  }
    55  
    56  func (tr Transitions) applyWithoutErr(t *testing.T, a Args) State {
    57  	s, err := tr.apply(a)
    58  	require.Nil(t, err)
    59  	return s
    60  }
    61  func (tr Transitions) applyWithErr(t *testing.T, a Args) error {
    62  	_, err := tr.apply(a)
    63  	require.NotNil(t, err)
    64  	return err
    65  }
    66  
    67  func TestBasicTransitions(t *testing.T) {
    68  	trans := Compile(Pattern{
    69  		state1{}: {
    70  			event1{}: {state2{}, noAction, ""},
    71  			event2{}: {state1{}, noAction, ""},
    72  		},
    73  		state2{}: {
    74  			event1{}: {state1{}, noAction, ""},
    75  			event2{}: {state2{}, noAction, ""},
    76  		},
    77  	})
    78  
    79  	// Valid transitions.
    80  	require.Equal(t, trans.applyWithoutErr(t, Args{Prev: state1{}, Event: event1{}}), state2{})
    81  	require.Equal(t, trans.applyWithoutErr(t, Args{Prev: state1{}, Event: event2{}}), state1{})
    82  	require.Equal(t, trans.applyWithoutErr(t, Args{Prev: state2{}, Event: event1{}}), state1{})
    83  	require.Equal(t, trans.applyWithoutErr(t, Args{Prev: state2{}, Event: event2{}}), state2{})
    84  
    85  	// Invalid transitions.
    86  	notFoundErr := &TransitionNotFoundError{}
    87  	require.IsType(t, trans.applyWithErr(t, Args{Prev: state3{}, Event: event1{}}), notFoundErr)
    88  	require.IsType(t, trans.applyWithErr(t, Args{Prev: state1{}, Event: event3{}}), notFoundErr)
    89  }
    90  
    91  func TestTransitionActions(t *testing.T) {
    92  	var extendedState int
    93  	trans := Compile(Pattern{
    94  		state1{}: {
    95  			event1{}: {state2{}, noErr(func(a Args) { *a.Extended.(*int) = 1 }), ""},
    96  			event2{}: {state1{}, noErr(func(a Args) { *a.Extended.(*int) = 2 }), ""},
    97  		},
    98  		state2{}: {
    99  			event1{}: {state1{}, noErr(func(a Args) { *a.Extended.(*int) = 3 }), ""},
   100  			event2{}: {state2{}, noErr(func(a Args) { *a.Extended.(*int) = 4 }), ""},
   101  		},
   102  	})
   103  
   104  	trans.applyWithoutErr(t, Args{Prev: state1{}, Event: event1{}, Extended: &extendedState})
   105  	require.Equal(t, extendedState, 1)
   106  
   107  	trans.applyWithoutErr(t, Args{Prev: state1{}, Event: event2{}, Extended: &extendedState})
   108  	require.Equal(t, extendedState, 2)
   109  
   110  	trans.applyWithoutErr(t, Args{Prev: state2{}, Event: event1{}, Extended: &extendedState})
   111  	require.Equal(t, extendedState, 3)
   112  
   113  	trans.applyWithoutErr(t, Args{Prev: state2{}, Event: event2{}, Extended: &extendedState})
   114  	require.Equal(t, extendedState, 4)
   115  }
   116  
   117  func TestTransitionsWithWildcards(t *testing.T) {
   118  	trans := Compile(Pattern{
   119  		state3{Any}: {
   120  			event3{Any}: {state1{}, noAction, ""},
   121  		},
   122  	})
   123  
   124  	require.Equal(t, trans.applyWithoutErr(t, Args{Prev: state3{True}, Event: event3{True}}), state1{})
   125  	require.Equal(t, trans.applyWithoutErr(t, Args{Prev: state3{True}, Event: event3{False}}), state1{})
   126  	require.Equal(t, trans.applyWithoutErr(t, Args{Prev: state3{False}, Event: event3{True}}), state1{})
   127  	require.Equal(t, trans.applyWithoutErr(t, Args{Prev: state3{False}, Event: event3{False}}), state1{})
   128  }
   129  
   130  func TestTransitionsWithVarBindings(t *testing.T) {
   131  	trans := Compile(Pattern{
   132  		state3{Var("a")}: {
   133  			event3{Var("b")}: {state4{Var("b"), Var("a")}, noAction, ""},
   134  		},
   135  	})
   136  
   137  	require.Equal(t, trans.applyWithoutErr(t, Args{Prev: state3{True}, Event: event3{True}}), state4{True, True})
   138  	require.Equal(t, trans.applyWithoutErr(t, Args{Prev: state3{True}, Event: event3{False}}), state4{False, True})
   139  	require.Equal(t, trans.applyWithoutErr(t, Args{Prev: state3{False}, Event: event3{True}}), state4{True, False})
   140  	require.Equal(t, trans.applyWithoutErr(t, Args{Prev: state3{False}, Event: event3{False}}), state4{False, False})
   141  }
   142  
   143  func TestCurState(t *testing.T) {
   144  	ctx := context.Background()
   145  	trans := Compile(Pattern{
   146  		state1{}: {
   147  			event1{}: {state2{}, func(a Args) error { return nil }, ""},
   148  		},
   149  	})
   150  	m := MakeMachine(trans, state1{}, nil /* es */)
   151  
   152  	e := Event(event1{})
   153  	if err := m.Apply(ctx, e); err != nil {
   154  		t.Fatal(err)
   155  	}
   156  	require.Equal(t, m.CurState(), state2{})
   157  }
   158  
   159  func BenchmarkPatternCompilation(b *testing.B) {
   160  	for i := 0; i < b.N; i++ {
   161  		_ = Compile(Pattern{
   162  			state1{}: {
   163  				event4{True, Any}:  {state2{}, noAction, ""},
   164  				event4{False, Any}: {state1{}, noAction, ""},
   165  			},
   166  			state2{}: {
   167  				event4{Any, Any}: {state2{}, noAction, ""},
   168  			},
   169  			state3{True}: {
   170  				event1{}: {state1{}, noAction, ""},
   171  			},
   172  			state3{False}: {
   173  				event3{True}:  {state2{}, noAction, ""},
   174  				event3{False}: {state1{}, noAction, ""},
   175  			},
   176  			state4{Var("x"), Var("y")}: {
   177  				event4{True, True}:   {state1{}, noAction, ""},
   178  				event4{True, False}:  {state2{}, noAction, ""},
   179  				event4{False, True}:  {state3{Var("x")}, noAction, ""},
   180  				event4{False, False}: {state4{Var("y"), Var("x")}, noAction, ""},
   181  			},
   182  		})
   183  	}
   184  }
   185  
   186  func BenchmarkStateTransition(b *testing.B) {
   187  	var extendedState int
   188  	ctx := context.Background()
   189  	trans := Compile(Pattern{
   190  		state1{}: {
   191  			event1{}: {state2{}, noErr(func(a Args) { *a.Extended.(*int) = 1 }), ""},
   192  			event2{}: {state1{}, noErr(func(a Args) { *a.Extended.(*int) = 2 }), ""},
   193  		},
   194  		state2{}: {
   195  			event1{}: {state1{}, noErr(func(a Args) { *a.Extended.(*int) = 3 }), ""},
   196  			event2{}: {state2{}, noErr(func(a Args) { *a.Extended.(*int) = 4 }), ""},
   197  		},
   198  		// Unused, but complicates transition graph. Demonstrates that a more
   199  		// complicated graph does not hurt runtime performance.
   200  		state3{True}: {
   201  			event1{}: {state1{}, noAction, ""},
   202  		},
   203  		state3{False}: {
   204  			event3{True}:  {state2{}, noAction, ""},
   205  			event3{False}: {state1{}, noAction, ""},
   206  		},
   207  		state4{Var("x"), Var("y")}: {
   208  			event4{True, True}:   {state1{}, noAction, ""},
   209  			event4{True, False}:  {state2{}, noAction, ""},
   210  			event4{False, True}:  {state3{Var("x")}, noAction, ""},
   211  			event4{False, False}: {state4{Var("y"), Var("x")}, noAction, ""},
   212  		},
   213  	})
   214  	m := MakeMachine(trans, state1{}, &extendedState)
   215  
   216  	b.ResetTimer()
   217  	for i := 0; i < b.N; i++ {
   218  		e := Event(event1{})
   219  		if i%2 == 1 {
   220  			e = event2{}
   221  		}
   222  		_ = m.ApplyWithPayload(ctx, e, 12)
   223  	}
   224  }