github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/util/fsm/match.go (about) 1 // Copyright 2017 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package fsm 12 13 import ( 14 "fmt" 15 "reflect" 16 ) 17 18 var ( 19 // True is a pattern that matches true booleans. 20 True Bool = b(true) 21 // False is a pattern that matches false booleans. 22 False Bool = b(false) 23 // Any is a pattern that matches any value. 24 Any = Var("") 25 ) 26 27 // Bool represents a boolean pattern. 28 type Bool interface { 29 bool() 30 // Get returns the value of a Bool. 31 Get() bool 32 } 33 34 // FromBool creates a Bool from a Go bool. 35 func FromBool(val bool) Bool { 36 return b(val) 37 } 38 39 type b bool 40 41 // Var allows variables to be bound to names and used in the match expression. 42 // If the variable binding name is the empty string, it acts as a wildcard but 43 // does not add any variable to the expression scope. 44 type Var string 45 46 func (b) bool() {} 47 func (x b) Get() bool { return bool(x) } 48 func (Var) bool() {} 49 50 // Get is part of the Bool interface. 51 func (Var) Get() bool { panic("can't call get on Var") } 52 53 // Pattern is a mapping from (State,Event) pairs to Transitions. When 54 // unexpanded, it may contain values like wildcards and variable bindings. 55 type Pattern map[State]map[Event]Transition 56 57 // expandPattern expands the States and Events in a Pattern to produce a new 58 // Pattern with no wildcards or variable bindings. For example: 59 // 60 // Pattern{ 61 // state3{Any}: { 62 // event1{}: {state2{}, ...}, 63 // }, 64 // state1{}: { 65 // event4{Any, Var("x")}: {state3{Var("x")}, ...}, 66 // }, 67 // } 68 // 69 // is expanded to: 70 // 71 // Pattern{ 72 // state3{False}: { 73 // event1{}: {state2{}, ...}, 74 // }, 75 // state3{True}: { 76 // event1{}: {state2{}, ...}, 77 // }, 78 // state1{}: { 79 // event4{False, False}: {state3{False}, ...}, 80 // event4{False, True}: {state3{True}, ...}, 81 // event4{True, False}: {state3{False}, ...}, 82 // event4{True, True}: {state3{True}, ...}, 83 // }, 84 // } 85 // 86 func expandPattern(p Pattern) Pattern { 87 xp := make(Pattern) 88 for s, sm := range p { 89 sVars := expandState(s) 90 for _, sVar := range sVars { 91 xs := sVar.v.Interface().(State) 92 93 xsm := xp[xs] 94 if xsm == nil { 95 xsm = make(map[Event]Transition) 96 xp[xs] = xsm 97 } 98 99 for e, t := range sm { 100 eVars := expandEvent(e) 101 for _, eVar := range eVars { 102 xe := eVar.v.Interface().(Event) 103 if _, ok := xsm[xe]; ok { 104 panic("match patterns overlap") 105 } 106 107 scope := mergeScope(sVar.scope, eVar.scope) 108 xsm[xe] = Transition{ 109 Next: bindState(t.Next, scope), 110 Action: t.Action, 111 Description: t.Description, 112 } 113 } 114 } 115 } 116 } 117 return xp 118 } 119 120 type bindings map[string]reflect.Value 121 type expandedVar struct { 122 v reflect.Value 123 scope bindings 124 } 125 126 func expandState(s State) []expandedVar { 127 if s == nil { 128 panic("found nil state") 129 } 130 return expandVar(reflect.ValueOf(s)) 131 } 132 func expandEvent(e Event) []expandedVar { 133 if e == nil { 134 panic("found nil event") 135 } 136 return expandVar(reflect.ValueOf(e)) 137 } 138 139 // expand expands all wildcards in the provided value. 140 func expandVar(v reflect.Value) []expandedVar { 141 for i := 0; i < v.NumField(); i++ { 142 f := v.Field(i).Interface() 143 switch t := f.(type) { 144 case nil: 145 panic("found nil field in match pattern") 146 case Bool: 147 switch bt := t.(type) { 148 case b: 149 // Can't expand. 150 case Var: 151 var xPats []expandedVar 152 for _, xVal := range expandBool(v, i) { 153 // xVal has its ith field set to a concrete value. We then 154 // recurse to expand any other wildcards or bindings. 155 recXPats := expandVar(xVal) 156 if bt != Any { 157 for j, rexXPat := range recXPats { 158 recXPats[j].scope = mergeScope( 159 rexXPat.scope, 160 bindings{string(bt): xVal.Field(i)}, 161 ) 162 } 163 } 164 xPats = append(xPats, recXPats...) 165 } 166 return xPats 167 default: 168 panic("unexpected Bool variant") 169 } 170 default: 171 // Can't expand. 172 } 173 } 174 return []expandedVar{{v: v}} 175 } 176 177 func expandBool(v reflect.Value, field int) []reflect.Value { 178 vTrue := reflect.New(v.Type()).Elem() 179 vTrue.Set(v) 180 vTrue.Field(field).Set(reflect.ValueOf(True)) 181 182 vFalse := reflect.New(v.Type()).Elem() 183 vFalse.Set(v) 184 vFalse.Field(field).Set(reflect.ValueOf(False)) 185 186 return []reflect.Value{vTrue, vFalse} 187 } 188 189 func bindState(s State, scope bindings) State { 190 if s == nil { 191 panic("found nil state") 192 } 193 xS := bindVar(reflect.ValueOf(s), scope) 194 return xS.Interface().(State) 195 } 196 197 // bindVar binds all variables in the provided value based on the variables in 198 // the scope. 199 func bindVar(v reflect.Value, scope bindings) reflect.Value { 200 newV := reflect.New(v.Type()).Elem() 201 newV.Set(v) 202 for i := 0; i < newV.NumField(); i++ { 203 f := newV.Field(i).Interface() 204 switch t := f.(type) { 205 case nil: 206 panic("found nil field in match expr") 207 case Bool: 208 switch bt := t.(type) { 209 case b: 210 // Nothing to bind. 211 case Var: 212 name := string(bt) 213 if name == "" { 214 panic("wildcard found in match expr") 215 } 216 if bv, ok := scope[name]; ok { 217 newV.Field(i).Set(bv) 218 } else { 219 panic(fmt.Sprintf("no binding for %q", name)) 220 } 221 default: 222 panic("unexpected Bool variant") 223 } 224 default: 225 // Nothing to bind. 226 } 227 } 228 return newV 229 } 230 231 func mergeScope(a, b bindings) bindings { 232 if len(a) == 0 && len(b) == 0 { 233 return nil 234 } 235 merged := make(bindings, len(a)+len(b)) 236 for n, v := range a { 237 merged[n] = v 238 } 239 for n, v := range b { 240 if _, ok := merged[n]; ok { 241 panic(fmt.Sprintf("multiple bindings for %q", n)) 242 } 243 merged[n] = v 244 } 245 return merged 246 }