github.com/rajeev159/opa@v0.45.0/topdown/arithmetic.go (about)

     1  // Copyright 2016 The OPA Authors.  All rights reserved.
     2  // Use of this source code is governed by an Apache2
     3  // license that can be found in the LICENSE file.
     4  
     5  package topdown
     6  
     7  import (
     8  	"math/big"
     9  
    10  	"fmt"
    11  
    12  	"github.com/open-policy-agent/opa/ast"
    13  	"github.com/open-policy-agent/opa/topdown/builtins"
    14  )
    15  
    16  type arithArity1 func(a *big.Float) (*big.Float, error)
    17  type arithArity2 func(a, b *big.Float) (*big.Float, error)
    18  
    19  func arithAbs(a *big.Float) (*big.Float, error) {
    20  	return a.Abs(a), nil
    21  }
    22  
    23  var halfAwayFromZero = big.NewFloat(0.5)
    24  
    25  func arithRound(a *big.Float) (*big.Float, error) {
    26  	var i *big.Int
    27  	if a.Signbit() {
    28  		i, _ = new(big.Float).Sub(a, halfAwayFromZero).Int(nil)
    29  	} else {
    30  		i, _ = new(big.Float).Add(a, halfAwayFromZero).Int(nil)
    31  	}
    32  	return new(big.Float).SetInt(i), nil
    33  }
    34  
    35  func arithCeil(a *big.Float) (*big.Float, error) {
    36  	i, _ := a.Int(nil)
    37  	f := new(big.Float).SetInt(i)
    38  
    39  	if f.Signbit() || a.Cmp(f) == 0 {
    40  		return f, nil
    41  	}
    42  
    43  	return new(big.Float).Add(f, big.NewFloat(1.0)), nil
    44  }
    45  
    46  func arithFloor(a *big.Float) (*big.Float, error) {
    47  	i, _ := a.Int(nil)
    48  	f := new(big.Float).SetInt(i)
    49  
    50  	if !f.Signbit() || a.Cmp(f) == 0 {
    51  		return f, nil
    52  	}
    53  
    54  	return new(big.Float).Sub(f, big.NewFloat(1.0)), nil
    55  }
    56  
    57  func arithPlus(a, b *big.Float) (*big.Float, error) {
    58  	return new(big.Float).Add(a, b), nil
    59  }
    60  
    61  func arithMinus(a, b *big.Float) (*big.Float, error) {
    62  	return new(big.Float).Sub(a, b), nil
    63  }
    64  
    65  func arithMultiply(a, b *big.Float) (*big.Float, error) {
    66  	return new(big.Float).Mul(a, b), nil
    67  }
    68  
    69  func arithDivide(a, b *big.Float) (*big.Float, error) {
    70  	i, acc := b.Int64()
    71  	if acc == big.Exact && i == 0 {
    72  		return nil, fmt.Errorf("divide by zero")
    73  	}
    74  	return new(big.Float).Quo(a, b), nil
    75  }
    76  
    77  func arithRem(a, b *big.Int) (*big.Int, error) {
    78  	if b.Int64() == 0 {
    79  		return nil, fmt.Errorf("modulo by zero")
    80  	}
    81  	return new(big.Int).Rem(a, b), nil
    82  }
    83  
    84  func builtinArithArity1(fn arithArity1) FunctionalBuiltin1 {
    85  	return func(a ast.Value) (ast.Value, error) {
    86  		n, err := builtins.NumberOperand(a, 1)
    87  		if err != nil {
    88  			return nil, err
    89  		}
    90  		f, err := fn(builtins.NumberToFloat(n))
    91  		if err != nil {
    92  			return nil, err
    93  		}
    94  		return builtins.FloatToNumber(f), nil
    95  	}
    96  }
    97  
    98  func builtinArithArity2(fn arithArity2) FunctionalBuiltin2 {
    99  	return func(a, b ast.Value) (ast.Value, error) {
   100  		n1, err := builtins.NumberOperand(a, 1)
   101  		if err != nil {
   102  			return nil, err
   103  		}
   104  		n2, err := builtins.NumberOperand(b, 2)
   105  		if err != nil {
   106  			return nil, err
   107  		}
   108  		f, err := fn(builtins.NumberToFloat(n1), builtins.NumberToFloat(n2))
   109  		if err != nil {
   110  			return nil, err
   111  		}
   112  		return builtins.FloatToNumber(f), nil
   113  	}
   114  }
   115  
   116  func builtinMinus(a, b ast.Value) (ast.Value, error) {
   117  
   118  	n1, ok1 := a.(ast.Number)
   119  	n2, ok2 := b.(ast.Number)
   120  
   121  	if ok1 && ok2 {
   122  		f, err := arithMinus(builtins.NumberToFloat(n1), builtins.NumberToFloat(n2))
   123  		if err != nil {
   124  			return nil, err
   125  		}
   126  		return builtins.FloatToNumber(f), nil
   127  	}
   128  
   129  	s1, ok3 := a.(ast.Set)
   130  	s2, ok4 := b.(ast.Set)
   131  
   132  	if ok3 && ok4 {
   133  		return s1.Diff(s2), nil
   134  	}
   135  
   136  	if !ok1 && !ok3 {
   137  		return nil, builtins.NewOperandTypeErr(1, a, "number", "set")
   138  	}
   139  
   140  	if ok2 {
   141  		return nil, builtins.NewOperandTypeErr(2, b, "set")
   142  	}
   143  
   144  	return nil, builtins.NewOperandTypeErr(2, b, "number")
   145  }
   146  
   147  func builtinRem(a, b ast.Value) (ast.Value, error) {
   148  	n1, ok1 := a.(ast.Number)
   149  	n2, ok2 := b.(ast.Number)
   150  
   151  	if ok1 && ok2 {
   152  
   153  		op1, err1 := builtins.NumberToInt(n1)
   154  		op2, err2 := builtins.NumberToInt(n2)
   155  
   156  		if err1 != nil || err2 != nil {
   157  			return nil, fmt.Errorf("modulo on floating-point number")
   158  		}
   159  
   160  		i, err := arithRem(op1, op2)
   161  		if err != nil {
   162  			return nil, err
   163  		}
   164  		return builtins.IntToNumber(i), nil
   165  	}
   166  
   167  	if !ok1 {
   168  		return nil, builtins.NewOperandTypeErr(1, a, "number")
   169  	}
   170  
   171  	return nil, builtins.NewOperandTypeErr(2, b, "number")
   172  }
   173  
   174  func init() {
   175  	RegisterFunctionalBuiltin1(ast.Abs.Name, builtinArithArity1(arithAbs))
   176  	RegisterFunctionalBuiltin1(ast.Round.Name, builtinArithArity1(arithRound))
   177  	RegisterFunctionalBuiltin1(ast.Ceil.Name, builtinArithArity1(arithCeil))
   178  	RegisterFunctionalBuiltin1(ast.Floor.Name, builtinArithArity1(arithFloor))
   179  	RegisterFunctionalBuiltin2(ast.Plus.Name, builtinArithArity2(arithPlus))
   180  	RegisterFunctionalBuiltin2(ast.Minus.Name, builtinMinus)
   181  	RegisterFunctionalBuiltin2(ast.Multiply.Name, builtinArithArity2(arithMultiply))
   182  	RegisterFunctionalBuiltin2(ast.Divide.Name, builtinArithArity2(arithDivide))
   183  	RegisterFunctionalBuiltin2(ast.Rem.Name, builtinRem)
   184  }