github.com/expr-lang/expr@v1.16.9/vm/runtime/helpers/main.go (about) 1 package main 2 3 import ( 4 "bytes" 5 "fmt" 6 "go/format" 7 "os" 8 "strings" 9 "text/template" 10 ) 11 12 func main() { 13 var b bytes.Buffer 14 err := template.Must( 15 template.New("helpers"). 16 Funcs(template.FuncMap{ 17 "cases": func(op string) string { return cases(op, uints, ints, floats) }, 18 "cases_int_only": func(op string) string { return cases(op, uints, ints) }, 19 "cases_with_duration": func(op string) string { 20 return cases(op, uints, ints, floats, []string{"time.Duration"}) 21 }, 22 "array_equal_cases": func() string { return arrayEqualCases([]string{"string"}, uints, ints, floats) }, 23 }). 24 Parse(helpers), 25 ).Execute(&b, nil) 26 if err != nil { 27 panic(err) 28 } 29 30 formatted, err := format.Source(b.Bytes()) 31 if err != nil { 32 panic(err) 33 } 34 fmt.Print(string(formatted)) 35 } 36 37 var ints = []string{ 38 "int", 39 "int8", 40 "int16", 41 "int32", 42 "int64", 43 } 44 45 var uints = []string{ 46 "uint", 47 "uint8", 48 "uint16", 49 "uint32", 50 "uint64", 51 } 52 53 var floats = []string{ 54 "float32", 55 "float64", 56 } 57 58 func cases(op string, xs ...[]string) string { 59 var types []string 60 for _, x := range xs { 61 types = append(types, x...) 62 } 63 64 _, _ = fmt.Fprintf(os.Stderr, "Generating %s cases for %v\n", op, types) 65 66 var out string 67 echo := func(s string, xs ...any) { 68 out += fmt.Sprintf(s, xs...) + "\n" 69 } 70 for _, a := range types { 71 echo(`case %v:`, a) 72 echo(`switch y := b.(type) {`) 73 for _, b := range types { 74 t := "int" 75 if isDuration(a) || isDuration(b) { 76 t = "time.Duration" 77 } 78 if isFloat(a) || isFloat(b) { 79 t = "float64" 80 } 81 echo(`case %v:`, b) 82 if op == "/" { 83 echo(`return float64(x) / float64(y)`) 84 } else { 85 echo(`return %v(x) %v %v(y)`, t, op, t) 86 } 87 } 88 echo(`}`) 89 } 90 return strings.TrimRight(out, "\n") 91 } 92 93 func arrayEqualCases(xs ...[]string) string { 94 var types []string 95 for _, x := range xs { 96 types = append(types, x...) 97 } 98 99 _, _ = fmt.Fprintf(os.Stderr, "Generating array equal cases for %v\n", types) 100 101 var out string 102 echo := func(s string, xs ...any) { 103 out += fmt.Sprintf(s, xs...) + "\n" 104 } 105 echo(`case []any:`) 106 echo(`switch y := b.(type) {`) 107 for _, a := range append(types, "any") { 108 echo(`case []%v:`, a) 109 echo(`if len(x) != len(y) { return false }`) 110 echo(`for i := range x {`) 111 echo(`if !Equal(x[i], y[i]) { return false }`) 112 echo(`}`) 113 echo("return true") 114 } 115 echo(`}`) 116 for _, a := range types { 117 echo(`case []%v:`, a) 118 echo(`switch y := b.(type) {`) 119 echo(`case []any:`) 120 echo(`return Equal(y, x)`) 121 echo(`case []%v:`, a) 122 echo(`if len(x) != len(y) { return false }`) 123 echo(`for i := range x {`) 124 echo(`if x[i] != y[i] { return false }`) 125 echo(`}`) 126 echo("return true") 127 echo(`}`) 128 } 129 return strings.TrimRight(out, "\n") 130 } 131 132 func isFloat(t string) bool { 133 return strings.HasPrefix(t, "float") 134 } 135 136 func isDuration(t string) bool { 137 return t == "time.Duration" 138 } 139 140 const helpers = `// Code generated by vm/runtime/helpers/main.go. DO NOT EDIT. 141 142 package runtime 143 144 import ( 145 "fmt" 146 "reflect" 147 "time" 148 ) 149 150 func Equal(a, b interface{}) bool { 151 switch x := a.(type) { 152 {{ cases "==" }} 153 {{ array_equal_cases }} 154 case string: 155 switch y := b.(type) { 156 case string: 157 return x == y 158 } 159 case time.Time: 160 switch y := b.(type) { 161 case time.Time: 162 return x.Equal(y) 163 } 164 case time.Duration: 165 switch y := b.(type) { 166 case time.Duration: 167 return x == y 168 } 169 case bool: 170 switch y := b.(type) { 171 case bool: 172 return x == y 173 } 174 } 175 if IsNil(a) && IsNil(b) { 176 return true 177 } 178 return reflect.DeepEqual(a, b) 179 } 180 181 func Less(a, b interface{}) bool { 182 switch x := a.(type) { 183 {{ cases "<" }} 184 case string: 185 switch y := b.(type) { 186 case string: 187 return x < y 188 } 189 case time.Time: 190 switch y := b.(type) { 191 case time.Time: 192 return x.Before(y) 193 } 194 case time.Duration: 195 switch y := b.(type) { 196 case time.Duration: 197 return x < y 198 } 199 } 200 panic(fmt.Sprintf("invalid operation: %T < %T", a, b)) 201 } 202 203 func More(a, b interface{}) bool { 204 switch x := a.(type) { 205 {{ cases ">" }} 206 case string: 207 switch y := b.(type) { 208 case string: 209 return x > y 210 } 211 case time.Time: 212 switch y := b.(type) { 213 case time.Time: 214 return x.After(y) 215 } 216 case time.Duration: 217 switch y := b.(type) { 218 case time.Duration: 219 return x > y 220 } 221 } 222 panic(fmt.Sprintf("invalid operation: %T > %T", a, b)) 223 } 224 225 func LessOrEqual(a, b interface{}) bool { 226 switch x := a.(type) { 227 {{ cases "<=" }} 228 case string: 229 switch y := b.(type) { 230 case string: 231 return x <= y 232 } 233 case time.Time: 234 switch y := b.(type) { 235 case time.Time: 236 return x.Before(y) || x.Equal(y) 237 } 238 case time.Duration: 239 switch y := b.(type) { 240 case time.Duration: 241 return x <= y 242 } 243 } 244 panic(fmt.Sprintf("invalid operation: %T <= %T", a, b)) 245 } 246 247 func MoreOrEqual(a, b interface{}) bool { 248 switch x := a.(type) { 249 {{ cases ">=" }} 250 case string: 251 switch y := b.(type) { 252 case string: 253 return x >= y 254 } 255 case time.Time: 256 switch y := b.(type) { 257 case time.Time: 258 return x.After(y) || x.Equal(y) 259 } 260 case time.Duration: 261 switch y := b.(type) { 262 case time.Duration: 263 return x >= y 264 } 265 } 266 panic(fmt.Sprintf("invalid operation: %T >= %T", a, b)) 267 } 268 269 func Add(a, b interface{}) interface{} { 270 switch x := a.(type) { 271 {{ cases "+" }} 272 case string: 273 switch y := b.(type) { 274 case string: 275 return x + y 276 } 277 case time.Time: 278 switch y := b.(type) { 279 case time.Duration: 280 return x.Add(y) 281 } 282 case time.Duration: 283 switch y := b.(type) { 284 case time.Time: 285 return y.Add(x) 286 case time.Duration: 287 return x + y 288 } 289 } 290 panic(fmt.Sprintf("invalid operation: %T + %T", a, b)) 291 } 292 293 func Subtract(a, b interface{}) interface{} { 294 switch x := a.(type) { 295 {{ cases "-" }} 296 case time.Time: 297 switch y := b.(type) { 298 case time.Time: 299 return x.Sub(y) 300 case time.Duration: 301 return x.Add(-y) 302 } 303 case time.Duration: 304 switch y := b.(type) { 305 case time.Duration: 306 return x - y 307 } 308 } 309 panic(fmt.Sprintf("invalid operation: %T - %T", a, b)) 310 } 311 312 func Multiply(a, b interface{}) interface{} { 313 switch x := a.(type) { 314 {{ cases_with_duration "*" }} 315 } 316 panic(fmt.Sprintf("invalid operation: %T * %T", a, b)) 317 } 318 319 func Divide(a, b interface{}) float64 { 320 switch x := a.(type) { 321 {{ cases "/" }} 322 } 323 panic(fmt.Sprintf("invalid operation: %T / %T", a, b)) 324 } 325 326 func Modulo(a, b interface{}) int { 327 switch x := a.(type) { 328 {{ cases_int_only "%" }} 329 } 330 panic(fmt.Sprintf("invalid operation: %T %% %T", a, b)) 331 } 332 `