github.com/richardwilkes/toolbox@v1.121.0/eval/evaluator.go (about)

     1  // Copyright (c) 2016-2024 by Richard A. Wilkes. All rights reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the Mozilla Public
     4  // License, version 2.0. If a copy of the MPL was not distributed with
     5  // this file, You can obtain one at http://mozilla.org/MPL/2.0/.
     6  //
     7  // This Source Code Form is "Incompatible With Secondary Licenses", as
     8  // defined by the Mozilla Public License, version 2.0.
     9  
    10  package eval
    11  
    12  import (
    13  	"strings"
    14  
    15  	"github.com/richardwilkes/toolbox/errs"
    16  )
    17  
// VariableResolver is used to resolve variables in expressions into their values.
type VariableResolver interface {
	// ResolveVariable returns the text to substitute for the named variable. The name is passed without its leading
	// '$'. The evaluator treats a result that is empty or consists only of whitespace as a failure to resolve the
	// variable.
	ResolveVariable(variableName string) string
}
    22  
// expressionOperand is an operand stack entry holding the trimmed text of an operand along with the unary operator, if
// any, that preceded it in the expression.
type expressionOperand struct {
	unaryOp *Operator // applied to the value after variable substitution; may be nil
	value   string    // operand text exactly as it appeared in the expression, with surrounding whitespace removed
}
    27  
// expressionOperator is an operator stack entry. It pairs an operator with the unary operator, if any, that was pending
// when the operator was pushed.
type expressionOperator struct {
	op      *Operator // the operator that was encountered
	unaryOp *Operator // consulted when the entry is an open parenthesis being closed; may be nil
}
    32  
// expressionTree is an operand stack entry that combines previously parsed operands. It is created either by
// processTree, which sets left, right and op, or when a parenthesized group that was preceded by a unary operator is
// closed, which sets only left and unaryOp.
type expressionTree struct {
	evaluator *Evaluator // the Evaluator that built this node; used to evaluate left and right
	left      any        // left child taken from the operand stack, or nil if none was available
	right     any        // right child taken from the operand stack, or nil if none was available
	op        *Operator  // operator joining left and right; nil for nodes that only carry a unary operator
	unaryOp   *Operator  // unary operator applied to the node's result; may be nil
}
    40  
// Function provides a signature for a Function. When invoked by the evaluator, it receives the Evaluator performing the
// evaluation and the text found between the function's parentheses, with any variables in that text already replaced
// by their resolved values.
type Function func(evaluator *Evaluator, arguments string) (any, error)
    43  
// parsedFunction is an operand stack entry representing a function call that was recognized during parsing but has not
// yet been invoked.
type parsedFunction struct {
	unaryOp  *Operator // unary operator that preceded the function name, applied to the function's result; may be nil
	function Function  // implementation looked up in Evaluator.Functions by the function's name
	args     string    // raw text between the function's opening parenthesis and its matching closing parenthesis
}
    49  
// Evaluator is used to evaluate an expression. If you do not have any variables that will be resolved, you can leave
// Resolver unset. StdOperators() and StdFunctions() can be used to populate the Operators and Functions fields.
//
// Each call to Evaluate discards and rebuilds the internal operand and operator stacks; EvaluateNew performs its work
// on a separate Evaluator instead.
type Evaluator struct {
	// Resolver supplies values for variables (names introduced by '$'). Evaluating text that contains a '$' while
	// Resolver is nil results in an error.
	Resolver      VariableResolver
	// Operators are tried in slice order at each position of the expression; the first one that matches is used.
	Operators     []*Operator
	// Functions maps a function name, as written in the expression, to its implementation.
	Functions     map[string]Function
	operandStack  []any                 // operands, functions and sub-trees awaiting evaluation
	operatorStack []*expressionOperator // operators that have not yet been folded into a tree
}
    59  
    60  // Evaluate an expression.
    61  func (e *Evaluator) Evaluate(expression string) (any, error) {
    62  	if err := e.parse(expression); err != nil {
    63  		return nil, err
    64  	}
    65  	for len(e.operatorStack) != 0 {
    66  		e.processTree()
    67  	}
    68  	if len(e.operandStack) == 0 {
    69  		return "", nil
    70  	}
    71  	return e.evaluateOperand(e.operandStack[len(e.operandStack)-1])
    72  }
    73  
    74  // EvaluateNew reuses the Resolver, Operators, and Functions from this Evaluator to create a new Evaluator and then
    75  // resolves an expression with it.
    76  func (e *Evaluator) EvaluateNew(expression string) (any, error) {
    77  	other := Evaluator{
    78  		Resolver:  e.Resolver,
    79  		Operators: e.Operators,
    80  		Functions: e.Functions,
    81  	}
    82  	return other.Evaluate(expression)
    83  }
    84  
    85  func (e *Evaluator) parse(expression string) error {
    86  	var unaryOp *Operator
    87  	haveOperand := false
    88  	e.operandStack = nil
    89  	e.operatorStack = nil
    90  	i := 0
    91  	for i < len(expression) {
    92  		ch := expression[i]
    93  		if ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' {
    94  			i++
    95  			continue
    96  		}
    97  		opIndex, op := e.nextOperator(expression, i, nil)
    98  		if opIndex > i || opIndex == -1 {
    99  			var err error
   100  			if i, err = e.processOperand(expression, i, opIndex, unaryOp); err != nil {
   101  				return err
   102  			}
   103  			haveOperand = true
   104  			unaryOp = nil
   105  		}
   106  		if opIndex == i {
   107  			if op != nil && op.EvaluateUnary != nil && i == 0 {
   108  				i = opIndex + len(op.Symbol)
   109  				if unaryOp != nil {
   110  					return errs.Newf("consecutive unary operators are not allowed at index %d", i)
   111  				}
   112  				unaryOp = op
   113  			} else {
   114  				var err error
   115  				if i, err = e.processOperator(expression, opIndex, op, haveOperand, unaryOp); err != nil {
   116  					return err
   117  				}
   118  				unaryOp = nil
   119  			}
   120  			if op == nil || op.Symbol != ")" {
   121  				haveOperand = false
   122  			}
   123  		}
   124  	}
   125  	return nil
   126  }
   127  
   128  func (e *Evaluator) nextOperator(expression string, start int, match *Operator) (int, *Operator) {
   129  	for i := start; i < len(expression); i++ {
   130  		if match != nil {
   131  			if match.match(expression, i, len(expression)) {
   132  				return i, match
   133  			}
   134  		} else {
   135  			for _, op := range e.Operators {
   136  				if op.match(expression, i, len(expression)) {
   137  					return i, op
   138  				}
   139  			}
   140  		}
   141  	}
   142  	return -1, nil
   143  }
   144  
   145  func (e *Evaluator) processOperand(expression string, start, opIndex int, unaryOp *Operator) (int, error) {
   146  	if opIndex == -1 {
   147  		text := strings.TrimSpace(expression[start:])
   148  		if text == "" {
   149  			return -1, errs.Newf("expression is invalid at index %d", start)
   150  		}
   151  		e.operandStack = append(e.operandStack, &expressionOperand{
   152  			value:   text,
   153  			unaryOp: unaryOp,
   154  		})
   155  		return len(expression), nil
   156  	}
   157  	text := strings.TrimSpace(expression[start:opIndex])
   158  	if text == "" {
   159  		return -1, errs.Newf("expression is invalid at index %d", start)
   160  	}
   161  	e.operandStack = append(e.operandStack, &expressionOperand{
   162  		value:   text,
   163  		unaryOp: unaryOp,
   164  	})
   165  	return opIndex, nil
   166  }
   167  
   168  func (e *Evaluator) processOperator(expression string, index int, op *Operator, haveOperand bool, unaryOp *Operator) (int, error) {
   169  	if haveOperand && op != nil && op.Symbol == "(" {
   170  		var err error
   171  		index, op, err = e.processFunction(expression, index)
   172  		if err != nil {
   173  			return -1, err
   174  		}
   175  		index += len(op.Symbol)
   176  		var tmp int
   177  		tmp, op = e.nextOperator(expression, index, nil)
   178  		if op == nil {
   179  			return index, nil
   180  		}
   181  		index = tmp
   182  	}
   183  	switch op.Symbol {
   184  	case "(":
   185  		e.operatorStack = append(e.operatorStack, &expressionOperator{
   186  			op:      op,
   187  			unaryOp: unaryOp,
   188  		})
   189  	case ")":
   190  		var stackOp *expressionOperator
   191  		if len(e.operatorStack) > 0 {
   192  			stackOp = e.operatorStack[len(e.operatorStack)-1]
   193  		}
   194  		for stackOp != nil && stackOp.op.Symbol != "(" {
   195  			e.processTree()
   196  			if len(e.operatorStack) > 0 {
   197  				stackOp = e.operatorStack[len(e.operatorStack)-1]
   198  			} else {
   199  				stackOp = nil
   200  			}
   201  		}
   202  		if len(e.operatorStack) == 0 {
   203  			return -1, errs.Newf("invalid expression at index %d", index)
   204  		}
   205  		stackOp = e.operatorStack[len(e.operatorStack)-1]
   206  		if stackOp.op.Symbol != "(" {
   207  			return -1, errs.Newf("invalid expression at index %d", index)
   208  		}
   209  		e.operatorStack = e.operatorStack[:len(e.operatorStack)-1]
   210  		if stackOp.unaryOp != nil {
   211  			left := e.operandStack[len(e.operandStack)-1]
   212  			e.operandStack = e.operandStack[:len(e.operandStack)-1]
   213  			e.operandStack = append(e.operandStack, &expressionTree{
   214  				evaluator: e,
   215  				left:      left,
   216  				unaryOp:   stackOp.unaryOp,
   217  			})
   218  		}
   219  	default:
   220  		if len(e.operatorStack) > 0 {
   221  			stackOp := e.operatorStack[len(e.operatorStack)-1]
   222  			for stackOp != nil && stackOp.op.Precedence >= op.Precedence {
   223  				e.processTree()
   224  				if len(e.operatorStack) > 0 {
   225  					stackOp = e.operatorStack[len(e.operatorStack)-1]
   226  				} else {
   227  					stackOp = nil
   228  				}
   229  			}
   230  		}
   231  		e.operatorStack = append(e.operatorStack, &expressionOperator{
   232  			op:      op,
   233  			unaryOp: unaryOp,
   234  		})
   235  	}
   236  	return index + len(op.Symbol), nil
   237  }
   238  
   239  func (e *Evaluator) processFunction(expression string, opIndex int) (int, *Operator, error) {
   240  	var op *Operator
   241  	parens := 1
   242  	next := opIndex
   243  	for parens > 0 {
   244  		if next, op = e.nextOperator(expression, next+1, nil); op == nil {
   245  			return -1, nil, errs.Newf("function not closed at index %d", opIndex)
   246  		}
   247  		switch op.Symbol {
   248  		case "(":
   249  			parens++
   250  		case ")":
   251  			parens--
   252  		default:
   253  		}
   254  	}
   255  	if len(e.operandStack) == 0 {
   256  		return -1, nil, errs.Newf("invalid stack at index %d", next)
   257  	}
   258  	operand, ok := e.operandStack[len(e.operandStack)-1].(*expressionOperand)
   259  	if !ok {
   260  		return -1, nil, errs.Newf("unexpected operand stack value at index %d", next)
   261  	}
   262  	e.operandStack = e.operandStack[:len(e.operandStack)-1]
   263  	f, exists := e.Functions[operand.value]
   264  	if !exists {
   265  		return -1, nil, errs.Newf("function not defined: %s", operand.value)
   266  	}
   267  	e.operandStack = append(e.operandStack, &parsedFunction{
   268  		function: f,
   269  		args:     expression[opIndex+1 : next],
   270  		unaryOp:  operand.unaryOp,
   271  	})
   272  	return next, op, nil
   273  }
   274  
   275  func (e *Evaluator) processTree() {
   276  	var right any
   277  	if len(e.operandStack) > 0 {
   278  		right = e.operandStack[len(e.operandStack)-1]
   279  		e.operandStack = e.operandStack[:len(e.operandStack)-1]
   280  	}
   281  	var left any
   282  	if len(e.operandStack) > 0 {
   283  		left = e.operandStack[len(e.operandStack)-1]
   284  		e.operandStack = e.operandStack[:len(e.operandStack)-1]
   285  	}
   286  	op := e.operatorStack[len(e.operatorStack)-1]
   287  	e.operatorStack = e.operatorStack[:len(e.operatorStack)-1]
   288  	e.operandStack = append(e.operandStack, &expressionTree{
   289  		evaluator: e,
   290  		left:      left,
   291  		right:     right,
   292  		op:        op.op,
   293  	})
   294  }
   295  
   296  func (e *Evaluator) evaluateOperand(operand any) (any, error) {
   297  	switch op := operand.(type) {
   298  	case *expressionTree:
   299  		left, err := op.evaluator.evaluateOperand(op.left)
   300  		if err != nil {
   301  			return nil, err
   302  		}
   303  		var right any
   304  		right, err = op.evaluator.evaluateOperand(op.right)
   305  		if err != nil {
   306  			return nil, err
   307  		}
   308  		if op.left != nil && op.right != nil {
   309  			if op.op.Evaluate == nil {
   310  				return nil, errs.New("operator does not have Evaluate function defined")
   311  			}
   312  			var v any
   313  			v, err = op.op.Evaluate(left, right)
   314  			if err != nil {
   315  				return nil, err
   316  			}
   317  			if op.unaryOp != nil && op.unaryOp.EvaluateUnary != nil {
   318  				return op.unaryOp.EvaluateUnary(v)
   319  			}
   320  			return v, nil
   321  		}
   322  		var v any
   323  		if op.right == nil {
   324  			v = left
   325  		} else {
   326  			v = right
   327  		}
   328  		if v != nil {
   329  			if op.unaryOp != nil && op.unaryOp.EvaluateUnary != nil {
   330  				v, err = op.unaryOp.EvaluateUnary(v)
   331  			} else if op.op != nil && op.op.EvaluateUnary != nil {
   332  				v, err = op.op.EvaluateUnary(v)
   333  			}
   334  			if err != nil {
   335  				return nil, err
   336  			}
   337  		}
   338  		if v == nil {
   339  			return nil, errs.New("expression is invalid")
   340  		}
   341  		return v, nil
   342  	case *expressionOperand:
   343  		v, err := e.replaceVariables(op.value)
   344  		if err != nil {
   345  			return nil, err
   346  		}
   347  		if op.unaryOp != nil && op.unaryOp.EvaluateUnary != nil {
   348  			return op.unaryOp.EvaluateUnary(v)
   349  		}
   350  		return v, nil
   351  	case *parsedFunction:
   352  		s, err := e.replaceVariables(op.args)
   353  		if err != nil {
   354  			return nil, err
   355  		}
   356  		var v any
   357  		v, err = op.function(e, s)
   358  		if err != nil {
   359  			return nil, err
   360  		}
   361  		if op.unaryOp != nil && op.unaryOp.EvaluateUnary != nil {
   362  			return op.unaryOp.EvaluateUnary(v)
   363  		}
   364  		return v, nil
   365  	default:
   366  		if op != nil {
   367  			return nil, errs.New("invalid expression")
   368  		}
   369  		return nil, nil
   370  	}
   371  }
   372  
   373  func (e *Evaluator) replaceVariables(expression string) (string, error) {
   374  	dollar := strings.IndexRune(expression, '$')
   375  	if dollar == -1 {
   376  		return expression, nil
   377  	}
   378  	if e.Resolver == nil {
   379  		return "", errs.Newf("no variable resolver, yet variables present at index %d", dollar)
   380  	}
   381  	for dollar >= 0 {
   382  		last := dollar
   383  		for i, ch := range expression[dollar+1:] {
   384  			if ch == '_' || ch == '.' || ch == '#' || (ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z') || (i != 0 && ch >= '0' && ch <= '9') {
   385  				last = dollar + 1 + i
   386  			} else {
   387  				break
   388  			}
   389  		}
   390  		if dollar == last {
   391  			return "", errs.Newf("invalid variable at index %d", dollar)
   392  		}
   393  		name := expression[dollar+1 : last+1]
   394  		v := e.Resolver.ResolveVariable(name)
   395  		if strings.TrimSpace(v) == "" {
   396  			return "", errs.Newf("unable to resolve variable $%s", name)
   397  		}
   398  		var buffer strings.Builder
   399  		if dollar > 0 {
   400  			buffer.WriteString(expression[:dollar])
   401  		}
   402  		buffer.WriteString(v)
   403  		if last+1 < len(expression) {
   404  			buffer.WriteString(expression[last+1:])
   405  		}
   406  		expression = buffer.String()
   407  		dollar = strings.IndexRune(expression, '$')
   408  	}
   409  	return expression, nil
   410  }
   411  
   412  // NextArg provides extraction of the next argument from an arguments string passed to a Function. An empty string will
   413  // be returned if no argument remains.
   414  func NextArg(args string) (arg, remaining string) {
   415  	parens := 0
   416  	for i, ch := range args {
   417  		switch {
   418  		case ch == '(':
   419  			parens++
   420  		case ch == ')':
   421  			parens--
   422  		case ch == ',' && parens == 0:
   423  			return args[:i], args[i+1:]
   424  		default:
   425  		}
   426  	}
   427  	return args, ""
   428  }