github.com/expr-lang/expr@v1.16.9/patcher/operator_override.go (about)

     1  package patcher
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  
     7  	"github.com/expr-lang/expr/ast"
     8  	"github.com/expr-lang/expr/builtin"
     9  	"github.com/expr-lang/expr/conf"
    10  )
    11  
// OperatorOverloading is an AST visitor that replaces binary nodes using
// Operator with a call to one of the functions named in Overloads, choosing
// the first candidate whose two operand parameters accept the operand types.
// Candidates are looked up in Functions first and in Types second.
type OperatorOverloading struct {
	Operator  string              // Operator token to overload.
	Overloads []string            // List of function names to replace operator with.
	Types     conf.TypesTable     // Env types.
	Functions conf.FunctionsTable // Env functions.
	applied   bool                // Flag to indicate if any changes were made to the tree; set by Visit and not cleared anywhere in this file.
}
    19  
    20  func (p *OperatorOverloading) Visit(node *ast.Node) {
    21  	binaryNode, ok := (*node).(*ast.BinaryNode)
    22  	if !ok {
    23  		return
    24  	}
    25  
    26  	if binaryNode.Operator != p.Operator {
    27  		return
    28  	}
    29  
    30  	leftType := binaryNode.Left.Type()
    31  	rightType := binaryNode.Right.Type()
    32  
    33  	ret, fn, ok := p.FindSuitableOperatorOverload(leftType, rightType)
    34  	if ok {
    35  		newNode := &ast.CallNode{
    36  			Callee:    &ast.IdentifierNode{Value: fn},
    37  			Arguments: []ast.Node{binaryNode.Left, binaryNode.Right},
    38  		}
    39  		newNode.SetType(ret)
    40  		ast.Patch(node, newNode)
    41  		p.applied = true
    42  	}
    43  }
    44  
    45  func (p *OperatorOverloading) ShouldRepeat() bool {
    46  	return p.applied
    47  }
    48  
    49  func (p *OperatorOverloading) FindSuitableOperatorOverload(l, r reflect.Type) (reflect.Type, string, bool) {
    50  	t, fn, ok := p.findSuitableOperatorOverloadInFunctions(l, r)
    51  	if !ok {
    52  		t, fn, ok = p.findSuitableOperatorOverloadInTypes(l, r)
    53  	}
    54  	return t, fn, ok
    55  }
    56  
    57  func (p *OperatorOverloading) findSuitableOperatorOverloadInTypes(l, r reflect.Type) (reflect.Type, string, bool) {
    58  	for _, fn := range p.Overloads {
    59  		fnType, ok := p.Types[fn]
    60  		if !ok {
    61  			continue
    62  		}
    63  		firstInIndex := 0
    64  		if fnType.Method {
    65  			firstInIndex = 1 // As first argument to method is receiver.
    66  		}
    67  		ret, done := checkTypeSuits(fnType.Type, l, r, firstInIndex)
    68  		if done {
    69  			return ret, fn, true
    70  		}
    71  	}
    72  	return nil, "", false
    73  }
    74  
    75  func (p *OperatorOverloading) findSuitableOperatorOverloadInFunctions(l, r reflect.Type) (reflect.Type, string, bool) {
    76  	for _, fn := range p.Overloads {
    77  		fnType, ok := p.Functions[fn]
    78  		if !ok {
    79  			continue
    80  		}
    81  		firstInIndex := 0
    82  		for _, overload := range fnType.Types {
    83  			ret, done := checkTypeSuits(overload, l, r, firstInIndex)
    84  			if done {
    85  				return ret, fn, true
    86  			}
    87  		}
    88  	}
    89  	return nil, "", false
    90  }
    91  
    92  func checkTypeSuits(t reflect.Type, l reflect.Type, r reflect.Type, firstInIndex int) (reflect.Type, bool) {
    93  	firstArgType := t.In(firstInIndex)
    94  	secondArgType := t.In(firstInIndex + 1)
    95  
    96  	firstArgumentFit := l == firstArgType || (firstArgType.Kind() == reflect.Interface && (l == nil || l.Implements(firstArgType)))
    97  	secondArgumentFit := r == secondArgType || (secondArgType.Kind() == reflect.Interface && (r == nil || r.Implements(secondArgType)))
    98  	if firstArgumentFit && secondArgumentFit {
    99  		return t.Out(0), true
   100  	}
   101  	return nil, false
   102  }
   103  
   104  func (p *OperatorOverloading) Check() {
   105  	for _, fn := range p.Overloads {
   106  		fnType, foundType := p.Types[fn]
   107  		fnFunc, foundFunc := p.Functions[fn]
   108  		if !foundFunc && (!foundType || fnType.Type.Kind() != reflect.Func) {
   109  			panic(fmt.Errorf("function %s for %s operator does not exist in the environment", fn, p.Operator))
   110  		}
   111  
   112  		if foundType {
   113  			checkType(fnType, fn, p.Operator)
   114  		}
   115  
   116  		if foundFunc {
   117  			checkFunc(fnFunc, fn, p.Operator)
   118  		}
   119  	}
   120  }
   121  
   122  func checkType(fnType conf.Tag, fn string, operator string) {
   123  	requiredNumIn := 2
   124  	if fnType.Method {
   125  		requiredNumIn = 3 // As first argument of method is receiver.
   126  	}
   127  	if fnType.Type.NumIn() != requiredNumIn || fnType.Type.NumOut() != 1 {
   128  		panic(fmt.Errorf("function %s for %s operator does not have a correct signature", fn, operator))
   129  	}
   130  }
   131  
   132  func checkFunc(fn *builtin.Function, name string, operator string) {
   133  	if len(fn.Types) == 0 {
   134  		panic(fmt.Errorf("function %q for %q operator misses types", name, operator))
   135  	}
   136  	for _, t := range fn.Types {
   137  		if t.NumIn() != 2 || t.NumOut() != 1 {
   138  			panic(fmt.Errorf("function %q for %q operator does not have a correct signature", name, operator))
   139  		}
   140  	}
   141  }