github.com/expr-lang/expr@v1.16.9/patcher/operator_override.go (about) 1 package patcher 2 3 import ( 4 "fmt" 5 "reflect" 6 7 "github.com/expr-lang/expr/ast" 8 "github.com/expr-lang/expr/builtin" 9 "github.com/expr-lang/expr/conf" 10 ) 11 12 type OperatorOverloading struct { 13 Operator string // Operator token to overload. 14 Overloads []string // List of function names to replace operator with. 15 Types conf.TypesTable // Env types. 16 Functions conf.FunctionsTable // Env functions. 17 applied bool // Flag to indicate if any changes were made to the tree. 18 } 19 20 func (p *OperatorOverloading) Visit(node *ast.Node) { 21 binaryNode, ok := (*node).(*ast.BinaryNode) 22 if !ok { 23 return 24 } 25 26 if binaryNode.Operator != p.Operator { 27 return 28 } 29 30 leftType := binaryNode.Left.Type() 31 rightType := binaryNode.Right.Type() 32 33 ret, fn, ok := p.FindSuitableOperatorOverload(leftType, rightType) 34 if ok { 35 newNode := &ast.CallNode{ 36 Callee: &ast.IdentifierNode{Value: fn}, 37 Arguments: []ast.Node{binaryNode.Left, binaryNode.Right}, 38 } 39 newNode.SetType(ret) 40 ast.Patch(node, newNode) 41 p.applied = true 42 } 43 } 44 45 func (p *OperatorOverloading) ShouldRepeat() bool { 46 return p.applied 47 } 48 49 func (p *OperatorOverloading) FindSuitableOperatorOverload(l, r reflect.Type) (reflect.Type, string, bool) { 50 t, fn, ok := p.findSuitableOperatorOverloadInFunctions(l, r) 51 if !ok { 52 t, fn, ok = p.findSuitableOperatorOverloadInTypes(l, r) 53 } 54 return t, fn, ok 55 } 56 57 func (p *OperatorOverloading) findSuitableOperatorOverloadInTypes(l, r reflect.Type) (reflect.Type, string, bool) { 58 for _, fn := range p.Overloads { 59 fnType, ok := p.Types[fn] 60 if !ok { 61 continue 62 } 63 firstInIndex := 0 64 if fnType.Method { 65 firstInIndex = 1 // As first argument to method is receiver. 66 } 67 ret, done := checkTypeSuits(fnType.Type, l, r, firstInIndex) 68 if done { 69 return ret, fn, true 70 } 71 } 72 return nil, "", false 73 } 74 75 func (p *OperatorOverloading) findSuitableOperatorOverloadInFunctions(l, r reflect.Type) (reflect.Type, string, bool) { 76 for _, fn := range p.Overloads { 77 fnType, ok := p.Functions[fn] 78 if !ok { 79 continue 80 } 81 firstInIndex := 0 82 for _, overload := range fnType.Types { 83 ret, done := checkTypeSuits(overload, l, r, firstInIndex) 84 if done { 85 return ret, fn, true 86 } 87 } 88 } 89 return nil, "", false 90 } 91 92 func checkTypeSuits(t reflect.Type, l reflect.Type, r reflect.Type, firstInIndex int) (reflect.Type, bool) { 93 firstArgType := t.In(firstInIndex) 94 secondArgType := t.In(firstInIndex + 1) 95 96 firstArgumentFit := l == firstArgType || (firstArgType.Kind() == reflect.Interface && (l == nil || l.Implements(firstArgType))) 97 secondArgumentFit := r == secondArgType || (secondArgType.Kind() == reflect.Interface && (r == nil || r.Implements(secondArgType))) 98 if firstArgumentFit && secondArgumentFit { 99 return t.Out(0), true 100 } 101 return nil, false 102 } 103 104 func (p *OperatorOverloading) Check() { 105 for _, fn := range p.Overloads { 106 fnType, foundType := p.Types[fn] 107 fnFunc, foundFunc := p.Functions[fn] 108 if !foundFunc && (!foundType || fnType.Type.Kind() != reflect.Func) { 109 panic(fmt.Errorf("function %s for %s operator does not exist in the environment", fn, p.Operator)) 110 } 111 112 if foundType { 113 checkType(fnType, fn, p.Operator) 114 } 115 116 if foundFunc { 117 checkFunc(fnFunc, fn, p.Operator) 118 } 119 } 120 } 121 122 func checkType(fnType conf.Tag, fn string, operator string) { 123 requiredNumIn := 2 124 if fnType.Method { 125 requiredNumIn = 3 // As first argument of method is receiver. 126 } 127 if fnType.Type.NumIn() != requiredNumIn || fnType.Type.NumOut() != 1 { 128 panic(fmt.Errorf("function %s for %s operator does not have a correct signature", fn, operator)) 129 } 130 } 131 132 func checkFunc(fn *builtin.Function, name string, operator string) { 133 if len(fn.Types) == 0 { 134 panic(fmt.Errorf("function %q for %q operator misses types", name, operator)) 135 } 136 for _, t := range fn.Types { 137 if t.NumIn() != 2 || t.NumOut() != 1 { 138 panic(fmt.Errorf("function %q for %q operator does not have a correct signature", name, operator)) 139 } 140 } 141 }