github.com/expr-lang/expr@v1.16.9/vm/runtime/helpers/main.go (about)

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/format"
     7  	"os"
     8  	"strings"
     9  	"text/template"
    10  )
    11  
    12  func main() {
    13  	var b bytes.Buffer
    14  	err := template.Must(
    15  		template.New("helpers").
    16  			Funcs(template.FuncMap{
    17  				"cases":          func(op string) string { return cases(op, uints, ints, floats) },
    18  				"cases_int_only": func(op string) string { return cases(op, uints, ints) },
    19  				"cases_with_duration": func(op string) string {
    20  					return cases(op, uints, ints, floats, []string{"time.Duration"})
    21  				},
    22  				"array_equal_cases": func() string { return arrayEqualCases([]string{"string"}, uints, ints, floats) },
    23  			}).
    24  			Parse(helpers),
    25  	).Execute(&b, nil)
    26  	if err != nil {
    27  		panic(err)
    28  	}
    29  
    30  	formatted, err := format.Source(b.Bytes())
    31  	if err != nil {
    32  		panic(err)
    33  	}
    34  	fmt.Print(string(formatted))
    35  }
    36  
    37  var ints = []string{
    38  	"int",
    39  	"int8",
    40  	"int16",
    41  	"int32",
    42  	"int64",
    43  }
    44  
    45  var uints = []string{
    46  	"uint",
    47  	"uint8",
    48  	"uint16",
    49  	"uint32",
    50  	"uint64",
    51  }
    52  
    53  var floats = []string{
    54  	"float32",
    55  	"float64",
    56  }
    57  
    58  func cases(op string, xs ...[]string) string {
    59  	var types []string
    60  	for _, x := range xs {
    61  		types = append(types, x...)
    62  	}
    63  
    64  	_, _ = fmt.Fprintf(os.Stderr, "Generating %s cases for %v\n", op, types)
    65  
    66  	var out string
    67  	echo := func(s string, xs ...any) {
    68  		out += fmt.Sprintf(s, xs...) + "\n"
    69  	}
    70  	for _, a := range types {
    71  		echo(`case %v:`, a)
    72  		echo(`switch y := b.(type) {`)
    73  		for _, b := range types {
    74  			t := "int"
    75  			if isDuration(a) || isDuration(b) {
    76  				t = "time.Duration"
    77  			}
    78  			if isFloat(a) || isFloat(b) {
    79  				t = "float64"
    80  			}
    81  			echo(`case %v:`, b)
    82  			if op == "/" {
    83  				echo(`return float64(x) / float64(y)`)
    84  			} else {
    85  				echo(`return %v(x) %v %v(y)`, t, op, t)
    86  			}
    87  		}
    88  		echo(`}`)
    89  	}
    90  	return strings.TrimRight(out, "\n")
    91  }
    92  
    93  func arrayEqualCases(xs ...[]string) string {
    94  	var types []string
    95  	for _, x := range xs {
    96  		types = append(types, x...)
    97  	}
    98  
    99  	_, _ = fmt.Fprintf(os.Stderr, "Generating array equal cases for %v\n", types)
   100  
   101  	var out string
   102  	echo := func(s string, xs ...any) {
   103  		out += fmt.Sprintf(s, xs...) + "\n"
   104  	}
   105  	echo(`case []any:`)
   106  	echo(`switch y := b.(type) {`)
   107  	for _, a := range append(types, "any") {
   108  		echo(`case []%v:`, a)
   109  		echo(`if len(x) != len(y) { return false }`)
   110  		echo(`for i := range x {`)
   111  		echo(`if !Equal(x[i], y[i]) { return false }`)
   112  		echo(`}`)
   113  		echo("return true")
   114  	}
   115  	echo(`}`)
   116  	for _, a := range types {
   117  		echo(`case []%v:`, a)
   118  		echo(`switch y := b.(type) {`)
   119  		echo(`case []any:`)
   120  		echo(`return Equal(y, x)`)
   121  		echo(`case []%v:`, a)
   122  		echo(`if len(x) != len(y) { return false }`)
   123  		echo(`for i := range x {`)
   124  		echo(`if x[i] != y[i] { return false }`)
   125  		echo(`}`)
   126  		echo("return true")
   127  		echo(`}`)
   128  	}
   129  	return strings.TrimRight(out, "\n")
   130  }
   131  
   132  func isFloat(t string) bool {
   133  	return strings.HasPrefix(t, "float")
   134  }
   135  
   136  func isDuration(t string) bool {
   137  	return t == "time.Duration"
   138  }
   139  
// helpers is the text/template source that main parses and executes to
// produce the generated file for package runtime. The {{ cases ... }},
// {{ cases_int_only ... }}, {{ cases_with_duration ... }} and
// {{ array_equal_cases }} actions are expanded by the functions registered in
// main's template.FuncMap; everything else is copied to the output verbatim,
// so any edit inside the literal changes the generated code.
const helpers = `// Code generated by vm/runtime/helpers/main.go. DO NOT EDIT.

package runtime

import (
	"fmt"
	"reflect"
	"time"
)

func Equal(a, b interface{}) bool {
	switch x := a.(type) {
	{{ cases "==" }}
	{{ array_equal_cases }}
	case string:
		switch y := b.(type) {
		case string:
			return x == y
		}
	case time.Time:
		switch y := b.(type) {
		case time.Time:
			return x.Equal(y)
		}
	case time.Duration:
		switch y := b.(type) {
		case time.Duration:
			return x == y
		}
	case bool:
		switch y := b.(type) {
		case bool:
			return x == y
		}
	}
	if IsNil(a) && IsNil(b) {
		return true
	}
	return reflect.DeepEqual(a, b)
}

func Less(a, b interface{}) bool {
	switch x := a.(type) {
	{{ cases "<" }}
	case string:
		switch y := b.(type) {
		case string:
			return x < y
		}
	case time.Time:
		switch y := b.(type) {
		case time.Time:
			return x.Before(y)
		}
	case time.Duration:
		switch y := b.(type) {
		case time.Duration:
			return x < y
		}
	}
	panic(fmt.Sprintf("invalid operation: %T < %T", a, b))
}

func More(a, b interface{}) bool {
	switch x := a.(type) {
	{{ cases ">" }}
	case string:
		switch y := b.(type) {
		case string:
			return x > y
		}
	case time.Time:
		switch y := b.(type) {
		case time.Time:
			return x.After(y)
		}
	case time.Duration:
		switch y := b.(type) {
		case time.Duration:
			return x > y
		}
	}
	panic(fmt.Sprintf("invalid operation: %T > %T", a, b))
}

func LessOrEqual(a, b interface{}) bool {
	switch x := a.(type) {
	{{ cases "<=" }}
	case string:
		switch y := b.(type) {
		case string:
			return x <= y
		}
	case time.Time:
		switch y := b.(type) {
		case time.Time:
			return x.Before(y) || x.Equal(y)
		}
	case time.Duration:
		switch y := b.(type) {
		case time.Duration:
			return x <= y
		}
	}
	panic(fmt.Sprintf("invalid operation: %T <= %T", a, b))
}

func MoreOrEqual(a, b interface{}) bool {
	switch x := a.(type) {
	{{ cases ">=" }}
	case string:
		switch y := b.(type) {
		case string:
			return x >= y
		}
	case time.Time:
		switch y := b.(type) {
		case time.Time:
			return x.After(y) || x.Equal(y)
		}
	case time.Duration:
		switch y := b.(type) {
		case time.Duration:
			return x >= y
		}
	}
	panic(fmt.Sprintf("invalid operation: %T >= %T", a, b))
}

func Add(a, b interface{}) interface{} {
	switch x := a.(type) {
	{{ cases "+" }}
	case string:
		switch y := b.(type) {
		case string:
			return x + y
		}
	case time.Time:
		switch y := b.(type) {
		case time.Duration:
			return x.Add(y)
		}
	case time.Duration:
		switch y := b.(type) {
		case time.Time:
			return y.Add(x)
		case time.Duration:
			return x + y
		}
	}
	panic(fmt.Sprintf("invalid operation: %T + %T", a, b))
}

func Subtract(a, b interface{}) interface{} {
	switch x := a.(type) {
	{{ cases "-" }}
	case time.Time:
		switch y := b.(type) {
		case time.Time:
			return x.Sub(y)
		case time.Duration:
			return x.Add(-y)
		}
	case time.Duration:
		switch y := b.(type) {
		case time.Duration:
			return x - y
		}
	}
	panic(fmt.Sprintf("invalid operation: %T - %T", a, b))
}

func Multiply(a, b interface{}) interface{} {
	switch x := a.(type) {
	{{ cases_with_duration "*" }}
	}
	panic(fmt.Sprintf("invalid operation: %T * %T", a, b))
}

func Divide(a, b interface{}) float64 {
	switch x := a.(type) {
	{{ cases "/" }}
	}
	panic(fmt.Sprintf("invalid operation: %T / %T", a, b))
}

func Modulo(a, b interface{}) int {
	switch x := a.(type) {
	{{ cases_int_only "%" }}
	}
	panic(fmt.Sprintf("invalid operation: %T %% %T", a, b))
}
`