github.com/expr-lang/expr@v1.16.9/optimizer/fold.go (about) 1 package optimizer 2 3 import ( 4 "fmt" 5 "math" 6 "reflect" 7 8 . "github.com/expr-lang/expr/ast" 9 "github.com/expr-lang/expr/file" 10 ) 11 12 var ( 13 integerType = reflect.TypeOf(0) 14 floatType = reflect.TypeOf(float64(0)) 15 stringType = reflect.TypeOf("") 16 ) 17 18 type fold struct { 19 applied bool 20 err *file.Error 21 } 22 23 func (fold *fold) Visit(node *Node) { 24 patch := func(newNode Node) { 25 fold.applied = true 26 Patch(node, newNode) 27 } 28 patchWithType := func(newNode Node) { 29 patch(newNode) 30 switch newNode.(type) { 31 case *IntegerNode: 32 newNode.SetType(integerType) 33 case *FloatNode: 34 newNode.SetType(floatType) 35 case *StringNode: 36 newNode.SetType(stringType) 37 default: 38 panic(fmt.Sprintf("unknown type %T", newNode)) 39 } 40 } 41 42 switch n := (*node).(type) { 43 case *UnaryNode: 44 switch n.Operator { 45 case "-": 46 if i, ok := n.Node.(*IntegerNode); ok { 47 patchWithType(&IntegerNode{Value: -i.Value}) 48 } 49 if i, ok := n.Node.(*FloatNode); ok { 50 patchWithType(&FloatNode{Value: -i.Value}) 51 } 52 case "+": 53 if i, ok := n.Node.(*IntegerNode); ok { 54 patchWithType(&IntegerNode{Value: i.Value}) 55 } 56 if i, ok := n.Node.(*FloatNode); ok { 57 patchWithType(&FloatNode{Value: i.Value}) 58 } 59 case "!", "not": 60 if a := toBool(n.Node); a != nil { 61 patch(&BoolNode{Value: !a.Value}) 62 } 63 } 64 65 case *BinaryNode: 66 switch n.Operator { 67 case "+": 68 { 69 a := toInteger(n.Left) 70 b := toInteger(n.Right) 71 if a != nil && b != nil { 72 patchWithType(&IntegerNode{Value: a.Value + b.Value}) 73 } 74 } 75 { 76 a := toInteger(n.Left) 77 b := toFloat(n.Right) 78 if a != nil && b != nil { 79 patchWithType(&FloatNode{Value: float64(a.Value) + b.Value}) 80 } 81 } 82 { 83 a := toFloat(n.Left) 84 b := toInteger(n.Right) 85 if a != nil && b != nil { 86 patchWithType(&FloatNode{Value: a.Value + float64(b.Value)}) 87 } 88 } 89 { 90 a := toFloat(n.Left) 91 b := toFloat(n.Right) 92 if a != nil && b != nil { 93 patchWithType(&FloatNode{Value: a.Value + b.Value}) 94 } 95 } 96 { 97 a := toString(n.Left) 98 b := toString(n.Right) 99 if a != nil && b != nil { 100 patch(&StringNode{Value: a.Value + b.Value}) 101 } 102 } 103 case "-": 104 { 105 a := toInteger(n.Left) 106 b := toInteger(n.Right) 107 if a != nil && b != nil { 108 patchWithType(&IntegerNode{Value: a.Value - b.Value}) 109 } 110 } 111 { 112 a := toInteger(n.Left) 113 b := toFloat(n.Right) 114 if a != nil && b != nil { 115 patchWithType(&FloatNode{Value: float64(a.Value) - b.Value}) 116 } 117 } 118 { 119 a := toFloat(n.Left) 120 b := toInteger(n.Right) 121 if a != nil && b != nil { 122 patchWithType(&FloatNode{Value: a.Value - float64(b.Value)}) 123 } 124 } 125 { 126 a := toFloat(n.Left) 127 b := toFloat(n.Right) 128 if a != nil && b != nil { 129 patchWithType(&FloatNode{Value: a.Value - b.Value}) 130 } 131 } 132 case "*": 133 { 134 a := toInteger(n.Left) 135 b := toInteger(n.Right) 136 if a != nil && b != nil { 137 patchWithType(&IntegerNode{Value: a.Value * b.Value}) 138 } 139 } 140 { 141 a := toInteger(n.Left) 142 b := toFloat(n.Right) 143 if a != nil && b != nil { 144 patchWithType(&FloatNode{Value: float64(a.Value) * b.Value}) 145 } 146 } 147 { 148 a := toFloat(n.Left) 149 b := toInteger(n.Right) 150 if a != nil && b != nil { 151 patchWithType(&FloatNode{Value: a.Value * float64(b.Value)}) 152 } 153 } 154 { 155 a := toFloat(n.Left) 156 b := toFloat(n.Right) 157 if a != nil && b != nil { 158 patchWithType(&FloatNode{Value: a.Value * b.Value}) 159 } 160 } 161 case "/": 162 { 163 a := toInteger(n.Left) 164 b := toInteger(n.Right) 165 if a != nil && b != nil { 166 patchWithType(&FloatNode{Value: float64(a.Value) / float64(b.Value)}) 167 } 168 } 169 { 170 a := toInteger(n.Left) 171 b := toFloat(n.Right) 172 if a != nil && b != nil { 173 patchWithType(&FloatNode{Value: float64(a.Value) / b.Value}) 174 } 175 } 176 { 177 a := toFloat(n.Left) 178 b := toInteger(n.Right) 179 if a != nil && b != nil { 180 patchWithType(&FloatNode{Value: a.Value / float64(b.Value)}) 181 } 182 } 183 { 184 a := toFloat(n.Left) 185 b := toFloat(n.Right) 186 if a != nil && b != nil { 187 patchWithType(&FloatNode{Value: a.Value / b.Value}) 188 } 189 } 190 case "%": 191 if a, ok := n.Left.(*IntegerNode); ok { 192 if b, ok := n.Right.(*IntegerNode); ok { 193 if b.Value == 0 { 194 fold.err = &file.Error{ 195 Location: (*node).Location(), 196 Message: "integer divide by zero", 197 } 198 return 199 } 200 patch(&IntegerNode{Value: a.Value % b.Value}) 201 } 202 } 203 case "**", "^": 204 { 205 a := toInteger(n.Left) 206 b := toInteger(n.Right) 207 if a != nil && b != nil { 208 patchWithType(&FloatNode{Value: math.Pow(float64(a.Value), float64(b.Value))}) 209 } 210 } 211 { 212 a := toInteger(n.Left) 213 b := toFloat(n.Right) 214 if a != nil && b != nil { 215 patchWithType(&FloatNode{Value: math.Pow(float64(a.Value), b.Value)}) 216 } 217 } 218 { 219 a := toFloat(n.Left) 220 b := toInteger(n.Right) 221 if a != nil && b != nil { 222 patchWithType(&FloatNode{Value: math.Pow(a.Value, float64(b.Value))}) 223 } 224 } 225 { 226 a := toFloat(n.Left) 227 b := toFloat(n.Right) 228 if a != nil && b != nil { 229 patchWithType(&FloatNode{Value: math.Pow(a.Value, b.Value)}) 230 } 231 } 232 case "and", "&&": 233 a := toBool(n.Left) 234 b := toBool(n.Right) 235 236 if a != nil && a.Value { // true and x 237 patch(n.Right) 238 } else if b != nil && b.Value { // x and true 239 patch(n.Left) 240 } else if (a != nil && !a.Value) || (b != nil && !b.Value) { // "x and false" or "false and x" 241 patch(&BoolNode{Value: false}) 242 } 243 case "or", "||": 244 a := toBool(n.Left) 245 b := toBool(n.Right) 246 247 if a != nil && !a.Value { // false or x 248 patch(n.Right) 249 } else if b != nil && !b.Value { // x or false 250 patch(n.Left) 251 } else if (a != nil && a.Value) || (b != nil && b.Value) { // "x or true" or "true or x" 252 patch(&BoolNode{Value: true}) 253 } 254 case "==": 255 { 256 a := toInteger(n.Left) 257 b := toInteger(n.Right) 258 if a != nil && b != nil { 259 patch(&BoolNode{Value: a.Value == b.Value}) 260 } 261 } 262 { 263 a := toString(n.Left) 264 b := toString(n.Right) 265 if a != nil && b != nil { 266 patch(&BoolNode{Value: a.Value == b.Value}) 267 } 268 } 269 { 270 a := toBool(n.Left) 271 b := toBool(n.Right) 272 if a != nil && b != nil { 273 patch(&BoolNode{Value: a.Value == b.Value}) 274 } 275 } 276 } 277 278 case *ArrayNode: 279 if len(n.Nodes) > 0 { 280 for _, a := range n.Nodes { 281 switch a.(type) { 282 case *IntegerNode, *FloatNode, *StringNode, *BoolNode: 283 continue 284 default: 285 return 286 } 287 } 288 value := make([]any, len(n.Nodes)) 289 for i, a := range n.Nodes { 290 switch b := a.(type) { 291 case *IntegerNode: 292 value[i] = b.Value 293 case *FloatNode: 294 value[i] = b.Value 295 case *StringNode: 296 value[i] = b.Value 297 case *BoolNode: 298 value[i] = b.Value 299 } 300 } 301 patch(&ConstantNode{Value: value}) 302 } 303 304 case *BuiltinNode: 305 switch n.Name { 306 case "filter": 307 if len(n.Arguments) != 2 { 308 return 309 } 310 if base, ok := n.Arguments[0].(*BuiltinNode); ok && base.Name == "filter" { 311 patch(&BuiltinNode{ 312 Name: "filter", 313 Arguments: []Node{ 314 base.Arguments[0], 315 &BinaryNode{ 316 Operator: "&&", 317 Left: base.Arguments[1], 318 Right: n.Arguments[1], 319 }, 320 }, 321 }) 322 } 323 } 324 } 325 } 326 327 func toString(n Node) *StringNode { 328 switch a := n.(type) { 329 case *StringNode: 330 return a 331 } 332 return nil 333 } 334 335 func toInteger(n Node) *IntegerNode { 336 switch a := n.(type) { 337 case *IntegerNode: 338 return a 339 } 340 return nil 341 } 342 343 func toFloat(n Node) *FloatNode { 344 switch a := n.(type) { 345 case *FloatNode: 346 return a 347 } 348 return nil 349 } 350 351 func toBool(n Node) *BoolNode { 352 switch a := n.(type) { 353 case *BoolNode: 354 return a 355 } 356 return nil 357 }