github.com/kubeshop/testkube@v1.17.23/pkg/tcl/expressionstcl/math.go (about)

     1  // Copyright 2024 Testkube.
     2  //
     3  // Licensed as a Testkube Pro file under the Testkube Community
     4  // License (the "License"); you may not use this file except in compliance with
     5  // the License. You may obtain a copy of the License at
     6  //
     7  //     https://github.com/kubeshop/testkube/blob/main/licenses/TCL.txt
     8  
     9  package expressionstcl
    10  
    11  import (
    12  	"errors"
    13  	"fmt"
    14  	"maps"
    15  	math2 "math"
    16  )
    17  
    18  type operator string
    19  
    20  const (
    21  	operatorEquals         operator = "="
    22  	operatorEqualsAlias    operator = "=="
    23  	operatorNotEquals      operator = "!="
    24  	operatorNotEqualsAlias operator = "<>"
    25  	operatorGt             operator = ">"
    26  	operatorGte            operator = ">="
    27  	operatorLt             operator = "<"
    28  	operatorLte            operator = "<="
    29  	operatorAnd            operator = "&&"
    30  	operatorOr             operator = "||"
    31  	operatorAdd            operator = "+"
    32  	operatorSubtract       operator = "-"
    33  	operatorModulo         operator = "%"
    34  	operatorDivide         operator = "/"
    35  	operatorMultiply       operator = "*"
    36  	operatorPower          operator = "**"
    37  )
    38  
    39  func getOperatorPriority(op operator) int {
    40  	switch op {
    41  	case operatorAnd, operatorOr:
    42  		return 0
    43  	case operatorEquals, operatorEqualsAlias, operatorNotEquals, operatorNotEqualsAlias,
    44  		operatorGt, operatorGte, operatorLt, operatorLte:
    45  		return 1
    46  	case operatorAdd, operatorSubtract:
    47  		return 2
    48  	case operatorMultiply, operatorDivide, operatorModulo:
    49  		return 3
    50  	case operatorPower:
    51  		return 4
    52  	}
    53  	panic("unknown operator: " + op)
    54  }
    55  
    56  type math struct {
    57  	operator operator
    58  	left     Expression
    59  	right    Expression
    60  }
    61  
    62  func newMath(operator operator, left Expression, right Expression) Expression {
    63  	if left == nil {
    64  		left = None
    65  	}
    66  	if right == nil {
    67  		right = None
    68  	}
    69  	return &math{operator: operator, left: left, right: right}
    70  }
    71  
    72  func runOp[T interface{}, U interface{}](v1 StaticValue, v2 StaticValue, mapper func(value StaticValue) (T, error), op func(s1, s2 T) U) (StaticValue, error) {
    73  	s1, err1 := mapper(v1)
    74  	if err1 != nil {
    75  		return nil, err1
    76  	}
    77  	s2, err2 := mapper(v2)
    78  	if err2 != nil {
    79  		return nil, err2
    80  	}
    81  	return NewValue(op(s1, s2)), nil
    82  }
    83  
    84  func staticString(v StaticValue) (string, error) {
    85  	return v.StringValue()
    86  }
    87  
    88  func staticFloat(v StaticValue) (float64, error) {
    89  	return v.FloatValue()
    90  }
    91  
    92  func staticBool(v StaticValue) (bool, error) {
    93  	return v.BoolValue()
    94  }
    95  
    96  func (s *math) performMath(v1 StaticValue, v2 StaticValue) (StaticValue, error) {
    97  	switch s.operator {
    98  	case operatorEquals, operatorEqualsAlias:
    99  		return runOp(v1, v2, staticString, func(s1, s2 string) bool {
   100  			return s1 == s2
   101  		})
   102  	case operatorNotEquals, operatorNotEqualsAlias:
   103  		return runOp(v1, v2, staticString, func(s1, s2 string) bool {
   104  			return s1 != s2
   105  		})
   106  	case operatorGt:
   107  		return runOp(v1, v2, staticFloat, func(s1, s2 float64) bool {
   108  			return s1 > s2
   109  		})
   110  	case operatorLt:
   111  		return runOp(v1, v2, staticFloat, func(s1, s2 float64) bool {
   112  			return s1 < s2
   113  		})
   114  	case operatorGte:
   115  		return runOp(v1, v2, staticFloat, func(s1, s2 float64) bool {
   116  			return s1 >= s2
   117  		})
   118  	case operatorLte:
   119  		return runOp(v1, v2, staticFloat, func(s1, s2 float64) bool {
   120  			return s1 <= s2
   121  		})
   122  	case operatorAnd:
   123  		return runOp(v1, v2, staticBool, func(s1, s2 bool) interface{} {
   124  			if s1 {
   125  				return v2.Value()
   126  			}
   127  			return v1.Value()
   128  		})
   129  	case operatorOr:
   130  		return runOp(v1, v2, staticBool, func(s1, s2 bool) interface{} {
   131  			if s1 {
   132  				return v1.Value()
   133  			}
   134  			return v2.Value()
   135  		})
   136  	case operatorAdd:
   137  		if v1.IsString() || v2.IsString() {
   138  			return runOp(v1, v2, staticString, func(s1, s2 string) string {
   139  				return s1 + s2
   140  			})
   141  		}
   142  		return runOp(v1, v2, staticFloat, func(s1, s2 float64) float64 {
   143  			return s1 + s2
   144  		})
   145  	case operatorSubtract:
   146  		return runOp(v1, v2, staticFloat, func(s1, s2 float64) float64 {
   147  			return s1 - s2
   148  		})
   149  	case operatorModulo:
   150  		divideByZero := false
   151  		res, err := runOp(v1, v2, staticFloat, func(s1, s2 float64) float64 {
   152  			if s2 == 0 {
   153  				divideByZero = true
   154  				return 0
   155  			}
   156  			return math2.Mod(s1, s2)
   157  		})
   158  		if divideByZero {
   159  			return nil, errors.New("cannot modulo by zero")
   160  		}
   161  		return res, err
   162  	case operatorDivide:
   163  		divideByZero := false
   164  		res, err := runOp(v1, v2, staticFloat, func(s1, s2 float64) float64 {
   165  			if s2 == 0 {
   166  				divideByZero = true
   167  				return 0
   168  			}
   169  			return s1 / s2
   170  		})
   171  		if divideByZero {
   172  			return nil, errors.New("cannot divide by zero")
   173  		}
   174  		return res, err
   175  	case operatorMultiply:
   176  		return runOp(v1, v2, staticFloat, func(s1, s2 float64) float64 {
   177  			return s1 * s2
   178  		})
   179  	case operatorPower:
   180  		return runOp(v1, v2, staticFloat, func(s1, s2 float64) float64 {
   181  			return math2.Pow(s1, s2)
   182  		})
   183  	default:
   184  	}
   185  	return nil, fmt.Errorf("unknown math operator: %s", s.operator)
   186  }
   187  
   188  func (s *math) Type() Type {
   189  	l := s.left.Type()
   190  	r := s.right.Type()
   191  	switch s.operator {
   192  	case operatorAnd, operatorOr:
   193  		if l == r {
   194  			return l
   195  		}
   196  		return TypeUnknown
   197  	case operatorPower, operatorModulo, operatorSubtract, operatorMultiply, operatorDivide:
   198  		return TypeFloat64
   199  	case operatorAdd:
   200  		if l == TypeString || r == TypeString {
   201  			return TypeString
   202  		}
   203  		return TypeFloat64
   204  	case operatorEquals, operatorNotEquals, operatorEqualsAlias, operatorNotEqualsAlias, operatorGt, operatorLt, operatorGte, operatorLte:
   205  		return TypeBool
   206  	default:
   207  		return TypeUnknown
   208  	}
   209  }
   210  
   211  func (s *math) itemString(v Expression) string {
   212  	if vv, ok := v.(*math); ok {
   213  		if getOperatorPriority(vv.operator) >= getOperatorPriority(s.operator) && (vv.operator != operatorAdd || v.Type() == vv.left.Type()) {
   214  			return v.String()
   215  		}
   216  	}
   217  	return v.SafeString()
   218  }
   219  
   220  func (s *math) String() string {
   221  	return s.itemString(s.left) + string(s.operator) + s.itemString(s.right)
   222  }
   223  
   224  func (s *math) SafeString() string {
   225  	return "(" + s.String() + ")"
   226  }
   227  
   228  func (s *math) Template() string {
   229  	// Simplify the template when it is possible
   230  	if s.operator == operatorAdd && s.Type() == TypeString {
   231  		return s.left.Template() + s.right.Template()
   232  	}
   233  	return "{{" + s.String() + "}}"
   234  }
   235  
   236  func (s *math) SafeResolve(m ...Machine) (v Expression, changed bool, err error) {
   237  	var ch bool
   238  	s.left, ch, err = s.left.SafeResolve(m...)
   239  	changed = changed || ch
   240  	if err != nil {
   241  		return
   242  	}
   243  
   244  	// Fast track for cutting dead paths
   245  	if s.left.Static() != nil {
   246  		if s.operator == operatorAnd {
   247  			b, err := s.left.Static().BoolValue()
   248  			if err == nil && !b {
   249  				return s.left, true, nil
   250  			} else if err == nil {
   251  				return s.right, true, nil
   252  			}
   253  		} else if s.operator == operatorOr {
   254  			b, err := s.left.Static().BoolValue()
   255  			if err == nil && b {
   256  				return s.left, true, nil
   257  			} else if err == nil {
   258  				return s.right, true, nil
   259  			}
   260  		}
   261  	}
   262  
   263  	s.right, ch, err = s.right.SafeResolve(m...)
   264  	changed = changed || ch
   265  	if err != nil {
   266  		return
   267  	}
   268  
   269  	// Fast track for cutting dead paths
   270  	t := s.left.Type()
   271  	if s.left.Static() == nil && s.right.Static() != nil && t != TypeUnknown && t == s.right.Type() && t == TypeBool {
   272  		if s.operator == operatorAnd {
   273  			b, err := s.right.Static().BoolValue()
   274  			if err == nil && !b {
   275  				return s.right, true, nil
   276  			} else if err == nil {
   277  				return s.left, true, nil
   278  			}
   279  		} else if s.operator == operatorOr {
   280  			b, err := s.right.Static().BoolValue()
   281  			if err == nil && b {
   282  				return s.right, true, nil
   283  			} else if err == nil {
   284  				return s.left, true, nil
   285  			}
   286  		}
   287  	}
   288  
   289  	if s.left.Static() != nil && s.right.Static() != nil {
   290  		res, err := s.performMath(s.left.Static(), s.right.Static())
   291  		if err != nil {
   292  			return nil, changed, fmt.Errorf("error while performing math: %s: %s", s.String(), err)
   293  		}
   294  		return res, true, nil
   295  	}
   296  	return s, changed, nil
   297  }
   298  
   299  func (s *math) Resolve(m ...Machine) (v Expression, err error) {
   300  	return deepResolve(s, m...)
   301  }
   302  
   303  func (s *math) Static() StaticValue {
   304  	return nil
   305  }
   306  
   307  func (s *math) Accessors() map[string]struct{} {
   308  	result := make(map[string]struct{})
   309  	maps.Copy(result, s.left.Accessors())
   310  	maps.Copy(result, s.right.Accessors())
   311  	return result
   312  }
   313  
   314  func (s *math) Functions() map[string]struct{} {
   315  	result := make(map[string]struct{})
   316  	maps.Copy(result, s.left.Functions())
   317  	maps.Copy(result, s.right.Functions())
   318  	return result
   319  }