github.com/expr-lang/expr@v1.16.9/optimizer/fold.go (about)

     1  package optimizer
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  	"reflect"
     7  
     8  	. "github.com/expr-lang/expr/ast"
     9  	"github.com/expr-lang/expr/file"
    10  )
    11  
    12  var (
    13  	integerType = reflect.TypeOf(0)
    14  	floatType   = reflect.TypeOf(float64(0))
    15  	stringType  = reflect.TypeOf("")
    16  )
    17  
    18  type fold struct {
    19  	applied bool
    20  	err     *file.Error
    21  }
    22  
    23  func (fold *fold) Visit(node *Node) {
    24  	patch := func(newNode Node) {
    25  		fold.applied = true
    26  		Patch(node, newNode)
    27  	}
    28  	patchWithType := func(newNode Node) {
    29  		patch(newNode)
    30  		switch newNode.(type) {
    31  		case *IntegerNode:
    32  			newNode.SetType(integerType)
    33  		case *FloatNode:
    34  			newNode.SetType(floatType)
    35  		case *StringNode:
    36  			newNode.SetType(stringType)
    37  		default:
    38  			panic(fmt.Sprintf("unknown type %T", newNode))
    39  		}
    40  	}
    41  
    42  	switch n := (*node).(type) {
    43  	case *UnaryNode:
    44  		switch n.Operator {
    45  		case "-":
    46  			if i, ok := n.Node.(*IntegerNode); ok {
    47  				patchWithType(&IntegerNode{Value: -i.Value})
    48  			}
    49  			if i, ok := n.Node.(*FloatNode); ok {
    50  				patchWithType(&FloatNode{Value: -i.Value})
    51  			}
    52  		case "+":
    53  			if i, ok := n.Node.(*IntegerNode); ok {
    54  				patchWithType(&IntegerNode{Value: i.Value})
    55  			}
    56  			if i, ok := n.Node.(*FloatNode); ok {
    57  				patchWithType(&FloatNode{Value: i.Value})
    58  			}
    59  		case "!", "not":
    60  			if a := toBool(n.Node); a != nil {
    61  				patch(&BoolNode{Value: !a.Value})
    62  			}
    63  		}
    64  
    65  	case *BinaryNode:
    66  		switch n.Operator {
    67  		case "+":
    68  			{
    69  				a := toInteger(n.Left)
    70  				b := toInteger(n.Right)
    71  				if a != nil && b != nil {
    72  					patchWithType(&IntegerNode{Value: a.Value + b.Value})
    73  				}
    74  			}
    75  			{
    76  				a := toInteger(n.Left)
    77  				b := toFloat(n.Right)
    78  				if a != nil && b != nil {
    79  					patchWithType(&FloatNode{Value: float64(a.Value) + b.Value})
    80  				}
    81  			}
    82  			{
    83  				a := toFloat(n.Left)
    84  				b := toInteger(n.Right)
    85  				if a != nil && b != nil {
    86  					patchWithType(&FloatNode{Value: a.Value + float64(b.Value)})
    87  				}
    88  			}
    89  			{
    90  				a := toFloat(n.Left)
    91  				b := toFloat(n.Right)
    92  				if a != nil && b != nil {
    93  					patchWithType(&FloatNode{Value: a.Value + b.Value})
    94  				}
    95  			}
    96  			{
    97  				a := toString(n.Left)
    98  				b := toString(n.Right)
    99  				if a != nil && b != nil {
   100  					patch(&StringNode{Value: a.Value + b.Value})
   101  				}
   102  			}
   103  		case "-":
   104  			{
   105  				a := toInteger(n.Left)
   106  				b := toInteger(n.Right)
   107  				if a != nil && b != nil {
   108  					patchWithType(&IntegerNode{Value: a.Value - b.Value})
   109  				}
   110  			}
   111  			{
   112  				a := toInteger(n.Left)
   113  				b := toFloat(n.Right)
   114  				if a != nil && b != nil {
   115  					patchWithType(&FloatNode{Value: float64(a.Value) - b.Value})
   116  				}
   117  			}
   118  			{
   119  				a := toFloat(n.Left)
   120  				b := toInteger(n.Right)
   121  				if a != nil && b != nil {
   122  					patchWithType(&FloatNode{Value: a.Value - float64(b.Value)})
   123  				}
   124  			}
   125  			{
   126  				a := toFloat(n.Left)
   127  				b := toFloat(n.Right)
   128  				if a != nil && b != nil {
   129  					patchWithType(&FloatNode{Value: a.Value - b.Value})
   130  				}
   131  			}
   132  		case "*":
   133  			{
   134  				a := toInteger(n.Left)
   135  				b := toInteger(n.Right)
   136  				if a != nil && b != nil {
   137  					patchWithType(&IntegerNode{Value: a.Value * b.Value})
   138  				}
   139  			}
   140  			{
   141  				a := toInteger(n.Left)
   142  				b := toFloat(n.Right)
   143  				if a != nil && b != nil {
   144  					patchWithType(&FloatNode{Value: float64(a.Value) * b.Value})
   145  				}
   146  			}
   147  			{
   148  				a := toFloat(n.Left)
   149  				b := toInteger(n.Right)
   150  				if a != nil && b != nil {
   151  					patchWithType(&FloatNode{Value: a.Value * float64(b.Value)})
   152  				}
   153  			}
   154  			{
   155  				a := toFloat(n.Left)
   156  				b := toFloat(n.Right)
   157  				if a != nil && b != nil {
   158  					patchWithType(&FloatNode{Value: a.Value * b.Value})
   159  				}
   160  			}
   161  		case "/":
   162  			{
   163  				a := toInteger(n.Left)
   164  				b := toInteger(n.Right)
   165  				if a != nil && b != nil {
   166  					patchWithType(&FloatNode{Value: float64(a.Value) / float64(b.Value)})
   167  				}
   168  			}
   169  			{
   170  				a := toInteger(n.Left)
   171  				b := toFloat(n.Right)
   172  				if a != nil && b != nil {
   173  					patchWithType(&FloatNode{Value: float64(a.Value) / b.Value})
   174  				}
   175  			}
   176  			{
   177  				a := toFloat(n.Left)
   178  				b := toInteger(n.Right)
   179  				if a != nil && b != nil {
   180  					patchWithType(&FloatNode{Value: a.Value / float64(b.Value)})
   181  				}
   182  			}
   183  			{
   184  				a := toFloat(n.Left)
   185  				b := toFloat(n.Right)
   186  				if a != nil && b != nil {
   187  					patchWithType(&FloatNode{Value: a.Value / b.Value})
   188  				}
   189  			}
   190  		case "%":
   191  			if a, ok := n.Left.(*IntegerNode); ok {
   192  				if b, ok := n.Right.(*IntegerNode); ok {
   193  					if b.Value == 0 {
   194  						fold.err = &file.Error{
   195  							Location: (*node).Location(),
   196  							Message:  "integer divide by zero",
   197  						}
   198  						return
   199  					}
   200  					patch(&IntegerNode{Value: a.Value % b.Value})
   201  				}
   202  			}
   203  		case "**", "^":
   204  			{
   205  				a := toInteger(n.Left)
   206  				b := toInteger(n.Right)
   207  				if a != nil && b != nil {
   208  					patchWithType(&FloatNode{Value: math.Pow(float64(a.Value), float64(b.Value))})
   209  				}
   210  			}
   211  			{
   212  				a := toInteger(n.Left)
   213  				b := toFloat(n.Right)
   214  				if a != nil && b != nil {
   215  					patchWithType(&FloatNode{Value: math.Pow(float64(a.Value), b.Value)})
   216  				}
   217  			}
   218  			{
   219  				a := toFloat(n.Left)
   220  				b := toInteger(n.Right)
   221  				if a != nil && b != nil {
   222  					patchWithType(&FloatNode{Value: math.Pow(a.Value, float64(b.Value))})
   223  				}
   224  			}
   225  			{
   226  				a := toFloat(n.Left)
   227  				b := toFloat(n.Right)
   228  				if a != nil && b != nil {
   229  					patchWithType(&FloatNode{Value: math.Pow(a.Value, b.Value)})
   230  				}
   231  			}
   232  		case "and", "&&":
   233  			a := toBool(n.Left)
   234  			b := toBool(n.Right)
   235  
   236  			if a != nil && a.Value { // true and x
   237  				patch(n.Right)
   238  			} else if b != nil && b.Value { // x and true
   239  				patch(n.Left)
   240  			} else if (a != nil && !a.Value) || (b != nil && !b.Value) { // "x and false" or "false and x"
   241  				patch(&BoolNode{Value: false})
   242  			}
   243  		case "or", "||":
   244  			a := toBool(n.Left)
   245  			b := toBool(n.Right)
   246  
   247  			if a != nil && !a.Value { // false or x
   248  				patch(n.Right)
   249  			} else if b != nil && !b.Value { // x or false
   250  				patch(n.Left)
   251  			} else if (a != nil && a.Value) || (b != nil && b.Value) { // "x or true" or "true or x"
   252  				patch(&BoolNode{Value: true})
   253  			}
   254  		case "==":
   255  			{
   256  				a := toInteger(n.Left)
   257  				b := toInteger(n.Right)
   258  				if a != nil && b != nil {
   259  					patch(&BoolNode{Value: a.Value == b.Value})
   260  				}
   261  			}
   262  			{
   263  				a := toString(n.Left)
   264  				b := toString(n.Right)
   265  				if a != nil && b != nil {
   266  					patch(&BoolNode{Value: a.Value == b.Value})
   267  				}
   268  			}
   269  			{
   270  				a := toBool(n.Left)
   271  				b := toBool(n.Right)
   272  				if a != nil && b != nil {
   273  					patch(&BoolNode{Value: a.Value == b.Value})
   274  				}
   275  			}
   276  		}
   277  
   278  	case *ArrayNode:
   279  		if len(n.Nodes) > 0 {
   280  			for _, a := range n.Nodes {
   281  				switch a.(type) {
   282  				case *IntegerNode, *FloatNode, *StringNode, *BoolNode:
   283  					continue
   284  				default:
   285  					return
   286  				}
   287  			}
   288  			value := make([]any, len(n.Nodes))
   289  			for i, a := range n.Nodes {
   290  				switch b := a.(type) {
   291  				case *IntegerNode:
   292  					value[i] = b.Value
   293  				case *FloatNode:
   294  					value[i] = b.Value
   295  				case *StringNode:
   296  					value[i] = b.Value
   297  				case *BoolNode:
   298  					value[i] = b.Value
   299  				}
   300  			}
   301  			patch(&ConstantNode{Value: value})
   302  		}
   303  
   304  	case *BuiltinNode:
   305  		switch n.Name {
   306  		case "filter":
   307  			if len(n.Arguments) != 2 {
   308  				return
   309  			}
   310  			if base, ok := n.Arguments[0].(*BuiltinNode); ok && base.Name == "filter" {
   311  				patch(&BuiltinNode{
   312  					Name: "filter",
   313  					Arguments: []Node{
   314  						base.Arguments[0],
   315  						&BinaryNode{
   316  							Operator: "&&",
   317  							Left:     base.Arguments[1],
   318  							Right:    n.Arguments[1],
   319  						},
   320  					},
   321  				})
   322  			}
   323  		}
   324  	}
   325  }
   326  
   327  func toString(n Node) *StringNode {
   328  	switch a := n.(type) {
   329  	case *StringNode:
   330  		return a
   331  	}
   332  	return nil
   333  }
   334  
   335  func toInteger(n Node) *IntegerNode {
   336  	switch a := n.(type) {
   337  	case *IntegerNode:
   338  		return a
   339  	}
   340  	return nil
   341  }
   342  
   343  func toFloat(n Node) *FloatNode {
   344  	switch a := n.(type) {
   345  	case *FloatNode:
   346  		return a
   347  	}
   348  	return nil
   349  }
   350  
   351  func toBool(n Node) *BoolNode {
   352  	switch a := n.(type) {
   353  	case *BoolNode:
   354  		return a
   355  	}
   356  	return nil
   357  }