github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/util/fsm/match.go (about)

     1  // Copyright 2017 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package fsm
    12  
    13  import (
    14  	"fmt"
    15  	"reflect"
    16  )
    17  
var (
	// True is a pattern that matches true booleans. It is a concrete Bool,
	// so calling Get on it returns true.
	True Bool = b(true)
	// False is a pattern that matches false booleans. It is a concrete Bool,
	// so calling Get on it returns false.
	False Bool = b(false)
	// Any is a pattern that matches any value. It is the Var with the empty
	// binding name, which acts as a wildcard and does not add a variable to
	// the expression scope.
	Any = Var("")
)
    26  
// Bool represents a boolean pattern. In this file it is implemented by b (a
// concrete true/false value) and by Var (a wildcard or named variable).
type Bool interface {
	// bool is an unexported marker method; types outside this package cannot
	// implement it directly.
	bool()
	// Get returns the value of a Bool.
	Get() bool
}
    33  
    34  // FromBool creates a Bool from a Go bool.
    35  func FromBool(val bool) Bool {
    36  	return b(val)
    37  }
    38  
    39  type b bool
    40  
    41  // Var allows variables to be bound to names and used in the match expression.
    42  // If the variable binding name is the empty string, it acts as a wildcard but
    43  // does not add any variable to the expression scope.
    44  type Var string
    45  
    46  func (b) bool()       {}
    47  func (x b) Get() bool { return bool(x) }
    48  func (Var) bool()     {}
    49  
    50  // Get is part of the Bool interface.
    51  func (Var) Get() bool { panic("can't call get on Var") }
    52  
// Pattern is a mapping from (State,Event) pairs to Transitions. When
// unexpanded, it may contain values like wildcards and variable bindings;
// expandPattern turns such a Pattern into one containing only concrete
// States and Events.
type Pattern map[State]map[Event]Transition
    56  
    57  // expandPattern expands the States and Events in a Pattern to produce a new
    58  // Pattern with no wildcards or variable bindings. For example:
    59  //
    60  //   Pattern{
    61  //       state3{Any}: {
    62  //           event1{}: {state2{}, ...},
    63  //       },
    64  //       state1{}: {
    65  //           event4{Any, Var("x")}: {state3{Var("x")}, ...},
    66  //       },
    67  //   }
    68  //
    69  // is expanded to:
    70  //
    71  //   Pattern{
    72  //       state3{False}: {
    73  //           event1{}: {state2{}, ...},
    74  //       },
    75  //       state3{True}: {
    76  //           event1{}: {state2{}, ...},
    77  //       },
    78  //       state1{}: {
    79  //           event4{False, False}: {state3{False}, ...},
    80  //           event4{False, True}:  {state3{True},  ...},
    81  //           event4{True, False}:  {state3{False}, ...},
    82  //           event4{True, True}:   {state3{True},  ...},
    83  //       },
    84  //   }
    85  //
    86  func expandPattern(p Pattern) Pattern {
    87  	xp := make(Pattern)
    88  	for s, sm := range p {
    89  		sVars := expandState(s)
    90  		for _, sVar := range sVars {
    91  			xs := sVar.v.Interface().(State)
    92  
    93  			xsm := xp[xs]
    94  			if xsm == nil {
    95  				xsm = make(map[Event]Transition)
    96  				xp[xs] = xsm
    97  			}
    98  
    99  			for e, t := range sm {
   100  				eVars := expandEvent(e)
   101  				for _, eVar := range eVars {
   102  					xe := eVar.v.Interface().(Event)
   103  					if _, ok := xsm[xe]; ok {
   104  						panic("match patterns overlap")
   105  					}
   106  
   107  					scope := mergeScope(sVar.scope, eVar.scope)
   108  					xsm[xe] = Transition{
   109  						Next:        bindState(t.Next, scope),
   110  						Action:      t.Action,
   111  						Description: t.Description,
   112  					}
   113  				}
   114  			}
   115  		}
   116  	}
   117  	return xp
   118  }
   119  
// bindings maps a variable name to the concrete value that was bound to it
// during expansion.
type bindings map[string]reflect.Value

// expandedVar is one concrete expansion of a pattern value together with the
// variable bindings that were made to produce it.
type expandedVar struct {
	// v is the expanded value.
	v reflect.Value
	// scope holds the bindings created while expanding v; it is left nil when
	// no named variable was expanded.
	scope bindings
}
   125  
   126  func expandState(s State) []expandedVar {
   127  	if s == nil {
   128  		panic("found nil state")
   129  	}
   130  	return expandVar(reflect.ValueOf(s))
   131  }
   132  func expandEvent(e Event) []expandedVar {
   133  	if e == nil {
   134  		panic("found nil event")
   135  	}
   136  	return expandVar(reflect.ValueOf(e))
   137  }
   138  
   139  // expand expands all wildcards in the provided value.
   140  func expandVar(v reflect.Value) []expandedVar {
   141  	for i := 0; i < v.NumField(); i++ {
   142  		f := v.Field(i).Interface()
   143  		switch t := f.(type) {
   144  		case nil:
   145  			panic("found nil field in match pattern")
   146  		case Bool:
   147  			switch bt := t.(type) {
   148  			case b:
   149  				// Can't expand.
   150  			case Var:
   151  				var xPats []expandedVar
   152  				for _, xVal := range expandBool(v, i) {
   153  					// xVal has its ith field set to a concrete value. We then
   154  					// recurse to expand any other wildcards or bindings.
   155  					recXPats := expandVar(xVal)
   156  					if bt != Any {
   157  						for j, rexXPat := range recXPats {
   158  							recXPats[j].scope = mergeScope(
   159  								rexXPat.scope,
   160  								bindings{string(bt): xVal.Field(i)},
   161  							)
   162  						}
   163  					}
   164  					xPats = append(xPats, recXPats...)
   165  				}
   166  				return xPats
   167  			default:
   168  				panic("unexpected Bool variant")
   169  			}
   170  		default:
   171  			// Can't expand.
   172  		}
   173  	}
   174  	return []expandedVar{{v: v}}
   175  }
   176  
   177  func expandBool(v reflect.Value, field int) []reflect.Value {
   178  	vTrue := reflect.New(v.Type()).Elem()
   179  	vTrue.Set(v)
   180  	vTrue.Field(field).Set(reflect.ValueOf(True))
   181  
   182  	vFalse := reflect.New(v.Type()).Elem()
   183  	vFalse.Set(v)
   184  	vFalse.Field(field).Set(reflect.ValueOf(False))
   185  
   186  	return []reflect.Value{vTrue, vFalse}
   187  }
   188  
   189  func bindState(s State, scope bindings) State {
   190  	if s == nil {
   191  		panic("found nil state")
   192  	}
   193  	xS := bindVar(reflect.ValueOf(s), scope)
   194  	return xS.Interface().(State)
   195  }
   196  
   197  // bindVar binds all variables in the provided value based on the variables in
   198  // the scope.
   199  func bindVar(v reflect.Value, scope bindings) reflect.Value {
   200  	newV := reflect.New(v.Type()).Elem()
   201  	newV.Set(v)
   202  	for i := 0; i < newV.NumField(); i++ {
   203  		f := newV.Field(i).Interface()
   204  		switch t := f.(type) {
   205  		case nil:
   206  			panic("found nil field in match expr")
   207  		case Bool:
   208  			switch bt := t.(type) {
   209  			case b:
   210  				// Nothing to bind.
   211  			case Var:
   212  				name := string(bt)
   213  				if name == "" {
   214  					panic("wildcard found in match expr")
   215  				}
   216  				if bv, ok := scope[name]; ok {
   217  					newV.Field(i).Set(bv)
   218  				} else {
   219  					panic(fmt.Sprintf("no binding for %q", name))
   220  				}
   221  			default:
   222  				panic("unexpected Bool variant")
   223  			}
   224  		default:
   225  			// Nothing to bind.
   226  		}
   227  	}
   228  	return newV
   229  }
   230  
   231  func mergeScope(a, b bindings) bindings {
   232  	if len(a) == 0 && len(b) == 0 {
   233  		return nil
   234  	}
   235  	merged := make(bindings, len(a)+len(b))
   236  	for n, v := range a {
   237  		merged[n] = v
   238  	}
   239  	for n, v := range b {
   240  		if _, ok := merged[n]; ok {
   241  			panic(fmt.Sprintf("multiple bindings for %q", n))
   242  		}
   243  		merged[n] = v
   244  	}
   245  	return merged
   246  }